Skip to content

Commit

Permalink
feat(compiler): support woppbs in simulation
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Sep 25, 2023
1 parent c640b4b commit 5a59261
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void sim_wop_pbs_crt(
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size);
uint32_t polynomial_size, uint32_t glwe_dim);

void sim_encode_expand_lut_for_boostrap(
uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ struct EncodeLutForCrtWopPBSOpPattern
encodeOp.getResult().getType().cast<mlir::RankedTensorType>(),
mlir::ValueRange{});

auto dynamicResultType =
toDynamicTensorType(encodeOp.getResult().getType());
auto dynamicLutType =
toDynamicTensorType(encodeOp.getInputLookupTable().getType());

mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
encodeOp.getLoc(), dynamicResultType, outputBuffer);
mlir::Value castedLUT = rewriter.create<mlir::tensor::CastOp>(
encodeOp.getLoc(), dynamicLutType, adaptor.getInputLookupTable());

auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr(
rewriter, encodeOp.getLoc(), encodeOp.getCrtDecompositionAttr());
auto crtBitsValue = mlir::concretelang::globalMemrefFromArrayAttr(
Expand All @@ -213,20 +223,18 @@ struct EncodeLutForCrtWopPBSOpPattern
if (insertForwardDeclaration(
encodeOp, rewriter, funcName,
rewriter.getFunctionType(
{encodeOp.getResult().getType(),
encodeOp.getInputLookupTable().getType(),
crtDecompValue.getType(), crtBitsValue.getType(),
rewriter.getIntegerType(32), rewriter.getIntegerType(1)},
{dynamicResultType, dynamicLutType, crtDecompValue.getType(),
crtBitsValue.getType(), rewriter.getIntegerType(32),
rewriter.getIntegerType(1)},
{}))
.failed()) {
return mlir::failure();
}

rewriter.create<mlir::func::CallOp>(
encodeOp.getLoc(), funcName, mlir::TypeRange{},
mlir::ValueRange({outputBuffer, adaptor.getInputLookupTable(),
crtDecompValue, crtBitsValue, modulusProductCst,
isSignedCst}));
mlir::ValueRange({castedOutputBuffer, castedLUT, crtDecompValue,
crtBitsValue, modulusProductCst, isSignedCst}));

rewriter.replaceOp(encodeOp, outputBuffer);

Expand Down Expand Up @@ -259,13 +267,18 @@ struct EncodePlaintextWithCrtOpPattern
epOp.getResult().getType().cast<mlir::RankedTensorType>(),
mlir::ValueRange{});

auto dynamicResultType = toDynamicTensorType(epOp.getResult().getType());

mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
epOp.getLoc(), dynamicResultType, outputBuffer);

auto ModsValue = mlir::concretelang::globalMemrefFromArrayAttr(
rewriter, epOp.getLoc(), epOp.getModsAttr());

if (insertForwardDeclaration(
epOp, rewriter, funcName,
rewriter.getFunctionType(
{epOp.getResult().getType(), epOp.getInput().getType(),
{dynamicResultType, epOp.getInput().getType(),
ModsValue.getType(), rewriter.getI64Type()},
{}))
.failed()) {
Expand All @@ -274,8 +287,8 @@ struct EncodePlaintextWithCrtOpPattern

rewriter.create<mlir::func::CallOp>(
epOp.getLoc(), funcName, mlir::TypeRange{},
mlir::ValueRange(
{outputBuffer, adaptor.getInput(), ModsValue, modsProductCst}));
mlir::ValueRange({castedOutputBuffer, adaptor.getInput(), ModsValue,
modsProductCst}));

rewriter.replaceOp(epOp, outputBuffer);

Expand Down Expand Up @@ -311,6 +324,22 @@ struct WopPBSGLWEOpPattern
.cast<mlir::RankedTensorType>(),
mlir::ValueRange{});

auto dynamicResultType = toDynamicTensorType(this->getTypeConverter()
->convertType(resultType)
.cast<mlir::TensorType>());
auto dynamicInputType = toDynamicTensorType(this->getTypeConverter()
->convertType(inputType)
.cast<mlir::TensorType>());
auto dynamicLutType =
toDynamicTensorType(wopPbs.getLookupTable().getType());

mlir::Value castedOutputBuffer = rewriter.create<mlir::tensor::CastOp>(
wopPbs.getLoc(), dynamicResultType, outputBuffer);
mlir::Value castedCiphertexts = rewriter.create<mlir::tensor::CastOp>(
wopPbs.getLoc(), dynamicInputType, adaptor.getCiphertexts());
mlir::Value castedLut = rewriter.create<mlir::tensor::CastOp>(
wopPbs.getLoc(), dynamicLutType, adaptor.getLookupTable());

auto lweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getPksk().getInputLweDim(), 32);
auto cbsLevelCountCst = rewriter.create<mlir::arith::ConstantIntOp>(
Expand All @@ -331,16 +360,17 @@ struct WopPBSGLWEOpPattern
wopPbs.getLoc(), adaptor.getPksk().getBaseLog(), 32);
auto polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getPksk().getOutputPolySize(), 32);
auto glweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
wopPbs.getLoc(), adaptor.getBsk().getGlweDim(), 32);

auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr(
rewriter, wopPbs.getLoc(), wopPbs.getCrtDecompositionAttr());

if (insertForwardDeclaration(
wopPbs, rewriter, funcName,
rewriter.getFunctionType(
{this->getTypeConverter()->convertType(resultType),
this->getTypeConverter()->convertType(inputType),
wopPbs.getLookupTable().getType(), crtDecompValue.getType(),
{dynamicResultType, dynamicInputType, dynamicLutType,
crtDecompValue.getType(), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
rewriter.getIntegerType(32), rewriter.getIntegerType(32),
Expand All @@ -353,11 +383,11 @@ struct WopPBSGLWEOpPattern

rewriter.create<mlir::func::CallOp>(
wopPbs.getLoc(), funcName, mlir::TypeRange{},
mlir::ValueRange({outputBuffer, adaptor.getCiphertexts(),
adaptor.getLookupTable(), crtDecompValue, lweDimCst,
cbsLevelCountCst, cbsBaseLogCst, kskLevelCountCst,
kskBaseLogCst, bskLevelCountCst, bskBaseLogCst,
fpkskLevelCountCst, fpkskBaseLogCst, polySizeCst}));
mlir::ValueRange({castedOutputBuffer, castedCiphertexts, castedLut,
crtDecompValue, lweDimCst, cbsLevelCountCst,
cbsBaseLogCst, kskLevelCountCst, kskBaseLogCst,
bskLevelCountCst, bskBaseLogCst, fpkskLevelCountCst,
fpkskBaseLogCst, polySizeCst, glweDimCst}));

rewriter.replaceOp(wopPbs, outputBuffer);

Expand Down Expand Up @@ -542,7 +572,8 @@ void SimulateTFHEPass::runOnOperation() {

target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::func::CallOp, mlir::memref::GetGlobalOp,
mlir::bufferization::AllocTensorOp, mlir::tensor::CastOp>();
mlir::memref::CastOp, mlir::bufferization::AllocTensorOp,
mlir::tensor::CastOp>();
// Make sure that no ops from `TFHE` remain after the lowering
target.addIllegalDialect<TFHE::TFHEDialect>();

Expand Down
66 changes: 64 additions & 2 deletions compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,70 @@ void sim_wop_pbs_crt(
uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log,
uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count,
uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log,
uint32_t polynomial_size) {
// TODO
uint32_t polynomial_size, uint32_t glwe_dim) {

// Check number of blocks
assert(out_size == in_size && out_size == crt_decomp_size);

uint64_t log_poly_size =
static_cast<uint64_t>(ceil(log2(static_cast<double>(polynomial_size))));

// Compute the numbers of bits to extract for each block and the total one.
uint64_t total_number_of_bits_per_block = 0;
auto number_of_bits_per_block = new uint64_t[crt_decomp_size]();
for (uint64_t i = 0; i < crt_decomp_size; i++) {
uint64_t modulus = crt_decomp_aligned[i + crt_decomp_offset];
uint64_t nb_bit_to_extract =
static_cast<uint64_t>(ceil(log2(static_cast<double>(modulus))));
number_of_bits_per_block[i] = nb_bit_to_extract;

total_number_of_bits_per_block += nb_bit_to_extract;
}

// Create the buffer of ciphertexts for storing the total number of bits to
// extract.
// The extracted bit should be in the following order:
//
// [msb(m%crt[n-1])..lsb(m%crt[n-1])...msb(m%crt[0])..lsb(m%crt[0])] where n
// is the size of the crt decomposition
auto extract_bits_output_buffer =
new uint64_t[total_number_of_bits_per_block]{0};

// Extraction of each bit for each block
for (int64_t i = crt_decomp_size - 1, extract_bits_output_offset = 0; i >= 0;
extract_bits_output_offset += number_of_bits_per_block[i--]) {
auto nb_bits_to_extract = number_of_bits_per_block[i];

size_t delta_log = 64 - nb_bits_to_extract;

auto in_block = in_aligned[in_offset + i];

// trick ( ct - delta/2 + delta/2^4 )
uint64_t sub = (uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 1)) -
(uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 5));
in_block -= sub;

simulation_extract_bit_lwe_ciphertext_u64(
&extract_bits_output_buffer[extract_bits_output_offset], in_block,
delta_log, nb_bits_to_extract, log_poly_size, glwe_dim, lwe_small_dim,
ksk_base_log, ksk_level_count, bsk_base_log, bsk_level_count, 64, 128);
}

size_t ct_in_count = total_number_of_bits_per_block;
size_t lut_size = 1 << ct_in_count;
size_t ct_out_count = out_size;
size_t lut_count = ct_out_count;

assert(lut_ct_size0 == lut_count);
assert(lut_ct_size1 == lut_size);

// Vertical packing
simulation_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64(
extract_bits_output_buffer, out_aligned + out_offset, ct_in_count,
ct_out_count, lut_size, lut_count, lut_ct_aligned + lut_ct_offset,
glwe_dim, log_poly_size, lwe_small_dim, bsk_level_count, bsk_base_log,
cbs_level_count, cbs_base_log, fpksk_level_count, fpksk_base_log, 64,
128);
}

uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ def compile_run_assert(
).reshape((4, 4)),
id="matul_chain_with_crt",
),
pytest.param(
"""
func.func @main(%arg0: !FHE.eint<14>, %arg1: tensor<16384xi64>) -> !FHE.eint<14> {
%cst = arith.constant 15 : i15
%v = "FHE.add_eint_int"(%arg0, %cst): (!FHE.eint<14>, i15) -> (!FHE.eint<14>)
%1 = "FHE.apply_lookup_table"(%v, %arg1): (!FHE.eint<14>, tensor<16384xi64>) -> (!FHE.eint<14>)
return %1: !FHE.eint<14>
}
""",
(
81,
np.array(range(16384), dtype=np.uint64),
),
96,
id="add_lut_crt",
)
]

end_to_end_parallel_fixture = [
Expand Down

0 comments on commit 5a59261

Please sign in to comment.