Skip to content

Commit

Permalink
add utilities for copying partially computed sumcheck results
Browse files Browse the repository at this point in the history
  • Loading branch information
rnburn committed Jan 22, 2025
1 parent 6d62ae4 commit 5f31fee
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 0 deletions.
26 changes: 26 additions & 0 deletions sxt/proof/sumcheck/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,32 @@ sxt_cc_component(
],
)

# Utilities for copying partially computed sumcheck results (mle_utility.{h,cc,t.cc}).
sxt_cc_component(
    name = "mle_utility",
    # Every project header included by mle_utility.cc is declared here so the
    # implementation does not rely on transitive dependencies.
    impl_deps = [
        "//sxt/base/container:span_utility",
        "//sxt/base/device:memory_utility",
        "//sxt/base/device:property",
        "//sxt/base/device:stream",
        "//sxt/base/error:assert",
        "//sxt/base/num:ceil_log2",
        "//sxt/base/num:divide_up",
        "//sxt/memory/management:managed_array",
        "//sxt/scalar25/type:element",
    ],
    test_deps = [
        "//sxt/base/device:stream",
        "//sxt/base/device:synchronization",
        "//sxt/base/test:unit_test",
        "//sxt/memory/management:managed_array",
        "//sxt/memory/resource:managed_device_resource",
        "//sxt/scalar25/type:element",
        "//sxt/scalar25/type:literal",
    ],
    deps = [
        "//sxt/base/container:span",
        "//sxt/memory/management:managed_array_fwd",
    ],
)

sxt_cc_component(
name = "workspace",
with_test = False,
Expand Down
93 changes: 93 additions & 0 deletions sxt/proof/sumcheck/mle_utility.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/mle_utility.h"

#include <algorithm>
#include <cstddef>
#include <iostream>

#include "sxt/base/container/span_utility.h"
#include "sxt/base/device/memory_utility.h"
#include "sxt/base/device/property.h"
#include "sxt/base/device/stream.h"
#include "sxt/base/error/assert.h"
#include "sxt/base/num/ceil_log2.h"
#include "sxt/base/num/divide_up.h"
#include "sxt/memory/management/managed_array.h"
#include "sxt/scalar25/type/element.h"

namespace sxt::prfsk {
//--------------------------------------------------------------------------------------------------
// copy_partial_mles
//--------------------------------------------------------------------------------------------------
/**
 * Gather, for every MLE, the entries in [a, b) together with the matching entries in
 * [mid + a, mid + b) (clamped to n), and copy them into partial_mles.
 *
 * mles is read as mles.size() / n tables of length n stored back to back, and mid is
 * 2^(max(ceil_log2(n), 1) - 1).
 *
 * partial_mles is resized to hold, for each MLE in order, its first part immediately followed
 * by its second part. The copies are enqueued on stream with basdv::async_copy_host_to_device,
 * so the caller must synchronize the stream before reading partial_mles.
 *
 * Requires a < b and b <= n (checked only by SXT_DEBUG_ASSERT).
 */
void copy_partial_mles(memmg::managed_array<s25t::element>& partial_mles, basdv::stream& stream,
                       basct::cspan<s25t::element> mles, unsigned n, unsigned a,
                       unsigned b) noexcept {
  // Check the range before it participates in any arithmetic (this also guarantees n > 0).
  SXT_DEBUG_ASSERT(a < b && b <= n);

  auto num_variables = std::max(basn::ceil_log2(n), 1);

  // All sizes and offsets are kept in std::size_t: products such as n * mle_index do not fit
  // in 32 bits once the combined length of the tables reaches 2^32 elements.
  const std::size_t mid = std::size_t{1} << (num_variables - 1);
  const std::size_t num_mles = mles.size() / n;
  const std::size_t part1_size = b - a;
  const auto ap = std::min<std::size_t>(mid + a, n);
  const auto bp = std::min<std::size_t>(mid + b, n);
  const std::size_t part2_size = bp - ap;

  // resize array
  const std::size_t partial_length = part1_size + part2_size;
  partial_mles.resize(partial_length * num_mles);

  // copy data
  for (std::size_t mle_index = 0; mle_index < num_mles; ++mle_index) {
    const std::size_t src_offset = n * mle_index;
    const std::size_t dst_offset = partial_length * mle_index;

    // first part: entries [a, b) of this MLE
    auto src = mles.subspan(src_offset + a, part1_size);
    auto dst = basct::subspan(partial_mles, dst_offset, part1_size);
    basdv::async_copy_host_to_device(dst, src, stream);

    // second part: entries [ap, bp) of this MLE; empty when mid + a already reaches n
    src = mles.subspan(src_offset + ap, part2_size);
    dst = basct::subspan(partial_mles, dst_offset + part1_size, part2_size);
    if (!src.empty()) {
      basdv::async_copy_host_to_device(dst, src, stream);
    }
  }
}

//--------------------------------------------------------------------------------------------------
// copy_folded_mles
//--------------------------------------------------------------------------------------------------
/**
 * Copy the leading (b - a) entries of every per-MLE slice in device_mles into host_mles.
 *
 * host_mles is read as host_mles.size() / np tables of length np, and device_mles as the same
 * number of equally sized slices. For each MLE, the first (b - a) entries of its device slice
 * are written to positions [a, b) of its host table.
 *
 * The copies are enqueued on stream with basdv::async_copy_device_to_host, so the caller must
 * synchronize the stream before reading host_mles. Nothing is enqueued when host_mles holds no
 * complete table.
 *
 * Requires np > 0, a <= b <= np, host_mles.size() to be a multiple of np, device_mles.size() to
 * be a multiple of the number of MLEs, and b - a to fit in a device slice (all checked only by
 * SXT_DEBUG_ASSERT).
 */
void copy_folded_mles(basct::span<s25t::element> host_mles, basdv::stream& stream,
                      basct::cspan<s25t::element> device_mles, unsigned np, unsigned a,
                      unsigned b) noexcept {
  // Validate before dividing by np or computing b - a in unsigned arithmetic.
  SXT_DEBUG_ASSERT(np > 0 && a <= b && b <= np);

  const std::size_t num_mles = host_mles.size() / np;
  SXT_DEBUG_ASSERT(host_mles.size() == num_mles * np);
  if (num_mles == 0) {
    // Nothing to copy; returning here also avoids dividing by zero below.
    return;
  }

  const std::size_t slice_n = device_mles.size() / num_mles;
  const std::size_t slice_np = b - a;
  SXT_DEBUG_ASSERT(device_mles.size() == num_mles * slice_n && slice_np <= slice_n);

  // Offsets are computed in std::size_t so that mle_index * np cannot wrap at 32 bits.
  for (std::size_t mle_index = 0; mle_index < num_mles; ++mle_index) {
    auto src = device_mles.subspan(mle_index * slice_n, slice_np);
    auto dst = host_mles.subspan(mle_index * np + a, slice_np);
    basdv::async_copy_device_to_host(dst, src, stream);
  }
}

//--------------------------------------------------------------------------------------------------
// get_gpu_memory_fraction
//--------------------------------------------------------------------------------------------------
/**
 * Ratio of the size in bytes of mles to the value reported by
 * basdv::get_total_device_memory().
 */
double get_gpu_memory_fraction(basct::cspan<s25t::element> mles) noexcept {
  const auto mle_bytes = mles.size() * sizeof(s25t::element);
  const auto device_bytes = basdv::get_total_device_memory();
  return static_cast<double>(mle_bytes) / static_cast<double>(device_bytes);
}
} // namespace sxt::prfsk
48 changes: 48 additions & 0 deletions sxt/proof/sumcheck/mle_utility.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "sxt/base/container/span.h"
#include "sxt/memory/management/managed_array_fwd.h"

namespace sxt::basdv {
class stream;
}
namespace sxt::s25t {
class element;
}

namespace sxt::prfsk {
//--------------------------------------------------------------------------------------------------
// copy_partial_mles
//--------------------------------------------------------------------------------------------------
/**
 * For each of the mles.size() / n tables of length n stored back to back in mles, copy the
 * entries in [a, b) followed by the entries in [mid + a, mid + b) (both bounds clamped to n)
 * into partial_mles, where mid = 2^(max(ceil_log2(n), 1) - 1).
 *
 * partial_mles is resized to fit the result. The copies are issued asynchronously on stream,
 * so the stream must be synchronized before partial_mles is read.
 *
 * Expects a < b and b <= n (verified by a debug-only assertion).
 */
void copy_partial_mles(memmg::managed_array<s25t::element>& partial_mles, basdv::stream& stream,
basct::cspan<s25t::element> mles, unsigned n, unsigned a,
unsigned b) noexcept;

//--------------------------------------------------------------------------------------------------
// copy_folded_mles
//--------------------------------------------------------------------------------------------------
/**
 * Treat host_mles as host_mles.size() / np tables of length np and device_mles as the same
 * number of equally sized slices. For each table, copy the first (b - a) entries of its device
 * slice into positions [a, b) of the host table.
 *
 * The copies are issued asynchronously on stream, so the stream must be synchronized before
 * host_mles is read.
 *
 * Expects b <= np and both sizes to divide evenly (verified by a debug-only assertion).
 */
void copy_folded_mles(basct::span<s25t::element> host_mles, basdv::stream& stream,
basct::cspan<s25t::element> device_mles, unsigned np, unsigned a,
unsigned b) noexcept;

//--------------------------------------------------------------------------------------------------
// get_gpu_memory_fraction
//--------------------------------------------------------------------------------------------------
/**
 * Return the size in bytes of mles divided by the total device memory reported by
 * basdv::get_total_device_memory().
 */
double get_gpu_memory_fraction(basct::cspan<s25t::element> mles) noexcept;
} // namespace sxt::prfsk
94 changes: 94 additions & 0 deletions sxt/proof/sumcheck/mle_utility.t.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
*
* Copyright 2025-present Space and Time Labs, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "sxt/proof/sumcheck/mle_utility.h"

#include <vector>

#include "sxt/base/device/stream.h"
#include "sxt/base/device/synchronization.h"
#include "sxt/base/test/unit_test.h"
#include "sxt/memory/management/managed_array.h"
#include "sxt/memory/resource/managed_device_resource.h"
#include "sxt/scalar25/type/element.h"
#include "sxt/scalar25/type/literal.h"

using namespace sxt;
using namespace sxt::prfsk;
using s25t::operator""_s25;

// Exercises copy_partial_mles: both the source MLEs and the destination array are allocated
// from memr::get_managed_device_resource().
TEST_CASE("we can copy a slice of mles to device memory") {
std::pmr::vector<s25t::element> mles{memr::get_managed_device_resource()};
memmg::managed_array<s25t::element> partial_mles{memr::get_managed_device_resource()};

basdv::stream stream;

SECTION("we can copy an mle with a single element") {
// n = 1, range [0, 1): only the first part exists, so the single element is copied as is
mles = {0x123_s25};
copy_partial_mles(partial_mles, stream, mles, 1, 0, 1);
// the copies are enqueued asynchronously; wait for them before comparing
basdv::synchronize_stream(stream);
memmg::managed_array<s25t::element> expected = {0x123_s25};
REQUIRE(partial_mles == expected);
}

SECTION("we can copy a slice of MLEs") {
// two MLEs of length n = 3; for range [0, 1) each MLE contributes its entry at index 0
// and its entry at index 2
mles = {0x1_s25, 0x2_s25, 0x3_s25, 0x4_s25, 0x5_s25, 0x6_s25};
copy_partial_mles(partial_mles, stream, mles, 3, 0, 1);
// the copies are enqueued asynchronously; wait for them before comparing
basdv::synchronize_stream(stream);
memmg::managed_array<s25t::element> expected = {0x1_s25, 0x3_s25, 0x4_s25, 0x6_s25};
REQUIRE(partial_mles == expected);
}
}

// Exercises copy_folded_mles: the source is allocated from memr::get_managed_device_resource()
// and the destination is an ordinary std::vector.
TEST_CASE("we can copy partially folded MLEs to the host") {
std::pmr::vector<s25t::element> device_mles{memr::get_managed_device_resource()};
std::vector<s25t::element> host_mles;

basdv::stream stream;

SECTION("we can copy a single element") {
// one MLE with np = 1 and range [0, 1)
device_mles = {0x123_s25};
host_mles.resize(1);
copy_folded_mles(host_mles, stream, device_mles, 1, 0, 1);
// the copies are enqueued asynchronously; wait for them before comparing
basdv::synchronize_stream(stream);
std::vector<s25t::element> expected = {0x123_s25};
REQUIRE(host_mles == expected);
}

SECTION("we can copy partially folded MLEs") {
// two MLEs with np = 2 and range [0, 1): only index 0 of each host table is written, the
// remaining entries keep the value they received from resize
device_mles = {0x123_s25, 0x456_s25};
host_mles.resize(4);
copy_folded_mles(host_mles, stream, device_mles, 2, 0, 1);
// the copies are enqueued asynchronously; wait for them before comparing
basdv::synchronize_stream(stream);
std::vector<s25t::element> expected = {0x123_s25, 0x0_s25, 0x456_s25, 0x0_s25};
REQUIRE(host_mles == expected);
}
}

// Checks get_gpu_memory_fraction for an empty input and for proportional scaling.
TEST_CASE("we can query the fraction of device memory taken by MLEs") {
  std::vector<s25t::element> elements;

  SECTION("we handle the zero case") {
    // no elements means no bytes, so the fraction is exactly zero
    REQUIRE(get_gpu_memory_fraction(elements) == 0.0);
  }

  SECTION("the fractions doubles if the length of mles doubles") {
    // fraction for a single element
    elements.resize(1);
    auto single_fraction = get_gpu_memory_fraction(elements);
    REQUIRE(single_fraction > 0);

    // twice as many elements should take twice the fraction
    elements.resize(2);
    auto double_fraction = get_gpu_memory_fraction(elements);
    REQUIRE(double_fraction == Catch::Approx(2 * single_fraction));
  }
}

0 comments on commit 5f31fee

Please sign in to comment.