Skip to content

Commit

Permalink
fix secretflow#811 for Insecure PackedB2A
Browse files Browse the repository at this point in the history
  • Loading branch information
fionser committed Aug 14, 2024
1 parent 9283cc1 commit 80dabb1
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 94 deletions.
155 changes: 61 additions & 94 deletions libspu/mpc/cheetah/ot/basic_ot_prot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,122 +68,89 @@ NdArrayRef BasicOTProtocols::B2A(const NdArrayRef &inp) {
return PackedB2A(inp);
}

// Convert the packed boolean shares to arithmetic share
// Input x in Z2k is the packed of b-bits for 1 <= b <= k.
// That is x0, x1, ..., x{b-1}
// Output y in Z2k such that y = \sum_i x{i}*2^i mod 2^k
//
// Ref: The ABY paper https://encrypto.de/papers/DSZ15.pdf Section E
NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) {
const auto *share_t = inp.eltype().as<BShrTy>();
auto field = inp.eltype().as<Ring2k>()->field();
const int64_t n = inp.numel();
size_t nbits = share_t->nbits() == 0 ? 1 : share_t->nbits();
if (n >= 8) {
// 8bits-align for a larger input
nbits = (nbits + 7) / 8 * 8;
}
SPU_ENFORCE(nbits > 0 && nbits <= 8 * SizeOf(field));

auto rand_bits = DISPATCH_ALL_FIELDS(field, [&]() {
if ((nbits & 7) or (n * inp.elsize()) & 7) {
// The SseTranspose requires the #rows and #columns is multiple of 8.
// Thus, we call the less efficient RandBits on margin cases.
return RandBits(field, {static_cast<int64_t>(n * nbits)});
}
const int64_t ring_width = SizeOf(field) * 8;

// More efficient randbits that ultilize collapse COTs.
int64_t B = nbits;
auto r = ring_randbit(field, {n * B}).as(makeType<BShrTy>(field, 1));
const int64_t numl = r.numel();
const int64_t n = inp.numel();
const int64_t nbits = share_t->nbits() == 0 ? 1 : share_t->nbits();
const int64_t numel = n * nbits;

NdArrayRef oup = ring_zeros(field, r.shape());
NdArrayRef cot_oup = ring_zeros(field, {numel});
NdArrayRef arith_oup = ring_zeros(field, inp.shape());
DISPATCH_ALL_FIELDS(field, [&]() {
using u2k = std::make_unsigned<ring2k_t>::type;
auto input = NdArrayView<const u2k>(r);
auto output = absl::MakeSpan(&oup.at<u2k>(0), numl);
SPU_ENFORCE(oup.isCompact());
auto input = NdArrayView<const u2k>(inp);
auto cot_output = absl::MakeSpan(&cot_oup.at<u2k>(0), cot_oup.numel());

if (Rank() == 0) {
std::vector<u2k> corr_data(numl);
// NOTE(lwj): Masking to make sure there is only single bit.
for (int64_t i = 0; i < numl; ++i) {
// corr=-2*xi
corr_data[i] = -((input[i] & 1) << 1);
std::vector<u2k> corr_data(numel);

for (int64_t k = 0; k < nbits; ++k) {
int64_t i = k * n;
auto msk = makeBitsMask<u2k>(ring_width - k);
for (int64_t j = 0; j < n; ++j) {
// corr[k] = -2*x0_k
corr_data[i + j] = -2 * ((input[j] >> k) & 1);
corr_data[i + j] &= msk;
}
}
// Run the multiple COT in the collapse mode.
// That is, the i-th COT returns output of `nbits - i` bits.
ferret_sender_->SendCAMCC_Collapse(absl::MakeSpan(corr_data), output,
/*bw*/ nbits, /*num_level*/ nbits);
ferret_sender_->Flush();
// That is, the k-th COT returns output of `ring_width - k` bits.
//
// The k-th COT gives the arithmetic share of the k-th bit of the input
// according to x_0 ^ x_1 = x_0 + x_1 - 2 * x_0 * x_1
ferret_sender_->SendCAMCC_Collapse(absl::MakeSpan(corr_data), cot_output,
/*bw*/ ring_width,
/*num_level*/ nbits);

for (int64_t i = 0; i < numl; ++i) {
output[i] = (input[i] & 1) - output[i];
ferret_sender_->Flush();
for (int64_t k = 0; k < nbits; ++k) {
int64_t i = k * n;
for (int64_t j = 0; j < n; ++j) {
cot_output[i + j] = ((input[j] >> k) & 1) - cot_output[i + j];
}
}
} else {
std::vector<uint8_t> choices(numl);
for (int64_t i = 0; i < numl; ++i) {
choices[i] = static_cast<uint8_t>(input[i] & 1);
}
ferret_receiver_->RecvCAMCC_Collapse(absl::MakeSpan(choices), output,
nbits, nbits);

for (int64_t i = 0; i < numl; ++i) {
output[i] = (input[i] & 1) + output[i];
// choice[k] is the k-th bit x1_k
std::vector<uint8_t> choices(numel);
for (int64_t k = 0; k < nbits; ++k) {
int64_t i = k * n;
for (int64_t j = 0; j < n; ++j) {
choices[i + j] = (input[j] >> k) & 1;
}
}
}

// oup.shape B x (n * T)
std::vector<uint8_t> tmp(B * n * inp.elsize());

// bit matrix transpose
SseTranspose(oup.data<uint8_t>(), tmp.data(), B, n * inp.elsize());
ferret_receiver_->RecvCAMCC_Collapse(absl::MakeSpan(choices), cot_output,
ring_width, nbits);

std::copy_n(tmp.data(), tmp.size(), oup.data<uint8_t>());
return oup;
});

// convert the bit form to integer form
auto rand = [&](NdArrayRef _bits) {
SPU_ENFORCE(_bits.isCompact(), "need compact input");
const int64_t n = _bits.numel() / nbits;
// init as all 0s.
auto iform = ring_zeros(field, inp.shape());
DISPATCH_ALL_FIELDS(field, [&]() {
auto bits = NdArrayView<const ring2k_t>(_bits);
auto digit = NdArrayView<ring2k_t>(iform);
for (int64_t i = 0; i < n; ++i) {
// LSB is bits[0]; MSB is bits[nbits - 1]
// We iterate the bits in reversed order
const size_t offset = i * nbits;
digit[i] = 0;
for (size_t j = nbits; j > 0; --j) {
digit[i] = (digit[i] << 1) | (bits[offset + j - 1] & 1);
for (int64_t k = 0; k < nbits; ++k) {
int64_t i = k * n;
for (int64_t j = 0; j < n; ++j) {
cot_output[i + j] = ((input[j] >> k) & 1) + cot_output[i + j];
}
}
});
return iform;
}(rand_bits);

// open c = x ^ r
auto opened = OpenShare(ring_xor(inp, rand), ReduceOp::XOR, nbits, conn_);
}

// compute c + (1 - 2*c)*<r>
NdArrayRef oup = ring_zeros(field, inp.shape());
DISPATCH_ALL_FIELDS(field, [&]() {
using u2k = std::make_unsigned<ring2k_t>::type;
int rank = Rank();
auto xr = NdArrayView<const u2k>(rand_bits);
auto xc = NdArrayView<const u2k>(opened);
auto xo = NdArrayView<ring2k_t>(oup);

for (int64_t i = 0; i < n; ++i) {
const size_t offset = i * nbits;
u2k this_elt = xc[i];
for (size_t j = 0; j < nbits; ++j, this_elt >>= 1) {
u2k c_ij = this_elt & 1;
ring2k_t one_bit = (1 - c_ij * 2) * xr[offset + j];
if (rank == 0) {
one_bit += c_ij;
}
xo[i] += (one_bit << j);
// <x> = \sum_k 2^k * <x_k>
// where <x_k> is the arithmetic share of the k-th bit
NdArrayView<u2k> arith(arith_oup);
for (int64_t k = 0; k < nbits; ++k) {
int64_t i = k * n;
for (int64_t j = 0; j < n; ++j) {
arith[j] += (cot_output[i + j] << k);
}
}
});
return oup;

return arith_oup;
}

// Math:
Expand Down
56 changes: 56 additions & 0 deletions libspu/mpc/cheetah/ot/emp/ferret_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,60 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) {
});
}

TEST_P(FerretCOTTest, COT_Collapse) {
size_t kWorldSize = 2;
int64_t n = 8;
auto field = GetParam();

const auto bw = SizeOf(field) * 8;
const int level = bw;

// generate random choices and correlation
const auto _correlation = ring_rand(field, {static_cast<int64_t>(n * level)});
const auto N = _correlation.numel();

NdArrayRef oup1 = ring_zeros(field, _correlation.shape());
NdArrayRef oup2 = ring_zeros(field, _correlation.shape());

std::vector<uint8_t> choices(N, 1);

DISPATCH_ALL_FIELDS(field, [&]() {
using u2k = std::make_unsigned<ring2k_t>::type;

auto out1_span = absl::MakeSpan(&oup1.at<u2k>(0), N);
auto out2_span = absl::MakeSpan(&oup2.at<u2k>(0), N);

NdArrayView<u2k> correlation(_correlation);

utils::simulate(kWorldSize, [&](std::shared_ptr<yacl::link::Context> ctx) {
auto conn = std::make_shared<Communicator>(ctx);
int rank = ctx->Rank();

EmpFerretOt ferret(conn, rank == 0);
if (rank == 0) {
ferret.SendCAMCC_Collapse(makeConstSpan(correlation), out1_span, bw,
level);
ferret.Flush();

} else {
ferret.RecvCAMCC_Collapse(absl::MakeSpan(choices), out2_span, bw,
level);
}
});

// Sample-major order
// n || n || n || .... || n
// k=level||k=level - 1||k=level - 2|| ....
for (int64_t i = 0; i < N; i += n) {
const auto cur_bw = bw - (i / n);
const auto mask = makeMask<ring2k_t>(cur_bw);
for (int64_t j = 0; j < n; ++j) {
ring2k_t c = (-out1_span[i + j] + out2_span[i + j]) & mask;
ring2k_t e = (choices[i + j] ? correlation[i + j] : 0) & mask;

ASSERT_EQ(c, e);
}
}
});
}
} // namespace spu::mpc::cheetah::test
65 changes: 65 additions & 0 deletions libspu/mpc/cheetah/ot/yacl/ferret_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,69 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) {
});
}

template <typename T>
T makeMask(int bw) {
if (bw == sizeof(T) * 8) {
return static_cast<T>(-1);
}
return (static_cast<T>(1) << bw) - 1;
}

TEST_P(FerretCOTTest, COT_Collapse) {
size_t kWorldSize = 2;
int64_t n = 8;
auto field = std::get<0>(GetParam());
auto use_ss = std::get<1>(GetParam());

const auto bw = SizeOf(field) * 8;
const int level = bw;

// generate random choices and correlation
const auto _correlation = ring_rand(field, {static_cast<int64_t>(n * level)});
const auto N = _correlation.numel();

NdArrayRef oup1 = ring_zeros(field, _correlation.shape());
NdArrayRef oup2 = ring_zeros(field, _correlation.shape());

std::vector<uint8_t> choices(N, 1);

DISPATCH_ALL_FIELDS(field, [&]() {
using u2k = std::make_unsigned<ring2k_t>::type;

auto out1_span = absl::MakeSpan(&oup1.at<u2k>(0), N);
auto out2_span = absl::MakeSpan(&oup2.at<u2k>(0), N);

NdArrayView<u2k> correlation(_correlation);

utils::simulate(kWorldSize, [&](std::shared_ptr<yacl::link::Context> ctx) {
auto conn = std::make_shared<Communicator>(ctx);
int rank = ctx->Rank();

YaclFerretOt ferret(conn, rank == 0, use_ss);
if (rank == 0) {
ferret.SendCAMCC_Collapse(makeConstSpan(correlation), out1_span, bw,
level);
ferret.Flush();

} else {
ferret.RecvCAMCC_Collapse(absl::MakeSpan(choices), out2_span, bw,
level);
}
});

// Sample-major order
// n || n || n || .... || n
// k=level||k=level - 1||k=level - 2|| ....
for (int64_t i = 0; i < N; i += n) {
const auto cur_bw = bw - (i / n);
const auto mask = makeMask<ring2k_t>(cur_bw);
for (int64_t j = 0; j < n; ++j) {
ring2k_t c = (-out1_span[i + j] + out2_span[i + j]) & mask;
ring2k_t e = (choices[i + j] ? correlation[i + j] : 0) & mask;

ASSERT_EQ(c, e);
}
}
});
}
} // namespace spu::mpc::cheetah::test

0 comments on commit 80dabb1

Please sign in to comment.