From 5a592616e5f241c1528e9ed26a8f27c6b0eb0da7 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 25 Sep 2023 10:42:01 +0100 Subject: [PATCH] feat(compiler): support woppbs in simulation --- .../include/concretelang/Runtime/simulation.h | 2 +- .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 69 ++++++++++++++----- .../compiler/lib/Runtime/simulation.cpp | 66 +++++++++++++++++- .../compiler/tests/python/test_simulation.py | 16 +++++ 4 files changed, 131 insertions(+), 22 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h index a19298addd..74688876ea 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Runtime/simulation.h @@ -77,7 +77,7 @@ void sim_wop_pbs_crt( uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log, uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count, uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log, - uint32_t polynomial_size); + uint32_t polynomial_size, uint32_t glwe_dim); void sim_encode_expand_lut_for_boostrap( uint64_t *in_allocated, uint64_t *in_aligned, uint64_t in_offset, diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp index dc98cec3f4..1df488ec12 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp @@ -205,6 +205,16 @@ struct EncodeLutForCrtWopPBSOpPattern encodeOp.getResult().getType().cast(), mlir::ValueRange{}); + auto dynamicResultType = + toDynamicTensorType(encodeOp.getResult().getType()); + auto dynamicLutType = + toDynamicTensorType(encodeOp.getInputLookupTable().getType()); + + mlir::Value castedOutputBuffer = rewriter.create( + encodeOp.getLoc(), dynamicResultType, outputBuffer); + mlir::Value castedLUT = rewriter.create( + encodeOp.getLoc(), dynamicLutType, adaptor.getInputLookupTable()); + auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr( rewriter, encodeOp.getLoc(), encodeOp.getCrtDecompositionAttr()); auto crtBitsValue = mlir::concretelang::globalMemrefFromArrayAttr( @@ -213,10 +223,9 @@ struct EncodeLutForCrtWopPBSOpPattern if (insertForwardDeclaration( encodeOp, rewriter, funcName, rewriter.getFunctionType( - {encodeOp.getResult().getType(), - encodeOp.getInputLookupTable().getType(), - crtDecompValue.getType(), crtBitsValue.getType(), - rewriter.getIntegerType(32), rewriter.getIntegerType(1)}, + {dynamicResultType, dynamicLutType, crtDecompValue.getType(), + crtBitsValue.getType(), rewriter.getIntegerType(32), + rewriter.getIntegerType(1)}, {})) .failed()) { return mlir::failure(); @@ -224,9 +233,8 @@ struct EncodeLutForCrtWopPBSOpPattern rewriter.create( encodeOp.getLoc(), funcName, mlir::TypeRange{}, - mlir::ValueRange({outputBuffer, adaptor.getInputLookupTable(), - crtDecompValue, crtBitsValue, modulusProductCst, - isSignedCst})); + mlir::ValueRange({castedOutputBuffer, castedLUT, crtDecompValue, + crtBitsValue, modulusProductCst, isSignedCst})); rewriter.replaceOp(encodeOp, outputBuffer); @@ -259,13 +267,18 @@ struct EncodePlaintextWithCrtOpPattern epOp.getResult().getType().cast(), mlir::ValueRange{}); + auto dynamicResultType = toDynamicTensorType(epOp.getResult().getType()); + + mlir::Value castedOutputBuffer = rewriter.create( + epOp.getLoc(), dynamicResultType, outputBuffer); + auto ModsValue = mlir::concretelang::globalMemrefFromArrayAttr( rewriter, epOp.getLoc(), epOp.getModsAttr()); if (insertForwardDeclaration( epOp, rewriter, funcName, rewriter.getFunctionType( - {epOp.getResult().getType(), epOp.getInput().getType(), + {dynamicResultType, epOp.getInput().getType(), ModsValue.getType(), rewriter.getI64Type()}, {})) .failed()) { @@ -274,8 +287,8 @@ struct EncodePlaintextWithCrtOpPattern rewriter.create( epOp.getLoc(), funcName, mlir::TypeRange{}, - mlir::ValueRange( - {outputBuffer, adaptor.getInput(), ModsValue, modsProductCst})); + mlir::ValueRange({castedOutputBuffer, adaptor.getInput(), ModsValue, + modsProductCst})); rewriter.replaceOp(epOp, outputBuffer); @@ -311,6 +324,22 @@ struct WopPBSGLWEOpPattern .cast(), mlir::ValueRange{}); + auto dynamicResultType = toDynamicTensorType(this->getTypeConverter() + ->convertType(resultType) + .cast()); + auto dynamicInputType = toDynamicTensorType(this->getTypeConverter() + ->convertType(inputType) + .cast()); + auto dynamicLutType = + toDynamicTensorType(wopPbs.getLookupTable().getType()); + + mlir::Value castedOutputBuffer = rewriter.create( + wopPbs.getLoc(), dynamicResultType, outputBuffer); + mlir::Value castedCiphertexts = rewriter.create( + wopPbs.getLoc(), dynamicInputType, adaptor.getCiphertexts()); + mlir::Value castedLut = rewriter.create( + wopPbs.getLoc(), dynamicLutType, adaptor.getLookupTable()); + auto lweDimCst = rewriter.create( wopPbs.getLoc(), adaptor.getPksk().getInputLweDim(), 32); auto cbsLevelCountCst = rewriter.create( @@ -331,6 +360,8 @@ struct WopPBSGLWEOpPattern wopPbs.getLoc(), adaptor.getPksk().getBaseLog(), 32); auto polySizeCst = rewriter.create( wopPbs.getLoc(), adaptor.getPksk().getOutputPolySize(), 32); + auto glweDimCst = rewriter.create( + wopPbs.getLoc(), adaptor.getBsk().getGlweDim(), 32); auto crtDecompValue = mlir::concretelang::globalMemrefFromArrayAttr( rewriter, wopPbs.getLoc(), wopPbs.getCrtDecompositionAttr()); @@ -338,9 +369,8 @@ struct WopPBSGLWEOpPattern if (insertForwardDeclaration( wopPbs, rewriter, funcName, rewriter.getFunctionType( - {this->getTypeConverter()->convertType(resultType), - this->getTypeConverter()->convertType(inputType), - wopPbs.getLookupTable().getType(), crtDecompValue.getType(), + {dynamicResultType, dynamicInputType, dynamicLutType, + crtDecompValue.getType(), rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), @@ -353,11 +383,11 @@ struct WopPBSGLWEOpPattern rewriter.create( wopPbs.getLoc(), funcName, mlir::TypeRange{}, - mlir::ValueRange({outputBuffer, adaptor.getCiphertexts(), - adaptor.getLookupTable(), crtDecompValue, lweDimCst, - cbsLevelCountCst, cbsBaseLogCst, kskLevelCountCst, - kskBaseLogCst, bskLevelCountCst, bskBaseLogCst, - fpkskLevelCountCst, fpkskBaseLogCst, polySizeCst})); + mlir::ValueRange({castedOutputBuffer, castedCiphertexts, castedLut, + crtDecompValue, lweDimCst, cbsLevelCountCst, + cbsBaseLogCst, kskLevelCountCst, kskBaseLogCst, + bskLevelCountCst, bskBaseLogCst, fpkskLevelCountCst, + fpkskBaseLogCst, polySizeCst, glweDimCst})); rewriter.replaceOp(wopPbs, outputBuffer); @@ -542,7 +572,8 @@ void SimulateTFHEPass::runOnOperation() { target.addLegalDialect(); target.addLegalOp(); + mlir::memref::CastOp, mlir::bufferization::AllocTensorOp, + mlir::tensor::CastOp>(); // Make sure that no ops from `TFHE` remain after the lowering target.addIllegalDialect(); diff --git a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp index 08cfbf6d92..3f7d87e2c0 100644 --- a/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp +++ b/compilers/concrete-compiler/compiler/lib/Runtime/simulation.cpp @@ -112,8 +112,70 @@ void sim_wop_pbs_crt( uint32_t lwe_small_dim, uint32_t cbs_level_count, uint32_t cbs_base_log, uint32_t ksk_level_count, uint32_t ksk_base_log, uint32_t bsk_level_count, uint32_t bsk_base_log, uint32_t fpksk_level_count, uint32_t fpksk_base_log, - uint32_t polynomial_size) { - // TODO + uint32_t polynomial_size, uint32_t glwe_dim) { + + // Check number of blocks + assert(out_size == in_size && out_size == crt_decomp_size); + + uint64_t log_poly_size = + static_cast(ceil(log2(static_cast(polynomial_size)))); + + // Compute the numbers of bits to extract for each block and the total one. + uint64_t total_number_of_bits_per_block = 0; + auto number_of_bits_per_block = new uint64_t[crt_decomp_size](); + for (uint64_t i = 0; i < crt_decomp_size; i++) { + uint64_t modulus = crt_decomp_aligned[i + crt_decomp_offset]; + uint64_t nb_bit_to_extract = + static_cast(ceil(log2(static_cast(modulus)))); + number_of_bits_per_block[i] = nb_bit_to_extract; + + total_number_of_bits_per_block += nb_bit_to_extract; + } + + // Create the buffer of ciphertexts for storing the total number of bits to + // extract. + // The extracted bit should be in the following order: + // + // [msb(m%crt[n-1])..lsb(m%crt[n-1])...msb(m%crt[0])..lsb(m%crt[0])] where n + // is the size of the crt decomposition + auto extract_bits_output_buffer = + new uint64_t[total_number_of_bits_per_block]{0}; + + // Extraction of each bit for each block + for (int64_t i = crt_decomp_size - 1, extract_bits_output_offset = 0; i >= 0; + extract_bits_output_offset += number_of_bits_per_block[i--]) { + auto nb_bits_to_extract = number_of_bits_per_block[i]; + + size_t delta_log = 64 - nb_bits_to_extract; + + auto in_block = in_aligned[in_offset + i]; + + // trick ( ct - delta/2 + delta/2^4 ) + uint64_t sub = (uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 1)) - + (uint64_t(1) << (uint64_t(64) - nb_bits_to_extract - 5)); + in_block -= sub; + + simulation_extract_bit_lwe_ciphertext_u64( + &extract_bits_output_buffer[extract_bits_output_offset], in_block, + delta_log, nb_bits_to_extract, log_poly_size, glwe_dim, lwe_small_dim, + ksk_base_log, ksk_level_count, bsk_base_log, bsk_level_count, 64, 128); + } + + size_t ct_in_count = total_number_of_bits_per_block; + size_t lut_size = 1 << ct_in_count; + size_t ct_out_count = out_size; + size_t lut_count = ct_out_count; + + assert(lut_ct_size0 == lut_count); + assert(lut_ct_size1 == lut_size); + + // Vertical packing + simulation_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64( + extract_bits_output_buffer, out_aligned + out_offset, ct_in_count, + ct_out_count, lut_size, lut_count, lut_ct_aligned + lut_ct_offset, + glwe_dim, log_poly_size, lwe_small_dim, bsk_level_count, bsk_base_log, + cbs_level_count, cbs_base_log, fpksk_level_count, fpksk_base_log, 64, + 128); } uint64_t sim_neg_lwe_u64(uint64_t plaintext) { return ~plaintext + 1; } diff --git a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py index 2a17483b4a..353b923455 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py @@ -170,6 +170,22 @@ def compile_run_assert( ).reshape((4, 4)), id="matul_chain_with_crt", ), + pytest.param( + """ + func.func @main(%arg0: !FHE.eint<14>, %arg1: tensor<16384xi64>) -> !FHE.eint<14> { + %cst = arith.constant 15 : i15 + %v = "FHE.add_eint_int"(%arg0, %cst): (!FHE.eint<14>, i15) -> (!FHE.eint<14>) + %1 = "FHE.apply_lookup_table"(%v, %arg1): (!FHE.eint<14>, tensor<16384xi64>) -> (!FHE.eint<14>) + return %1: !FHE.eint<14> + } + """, + ( + 81, + np.array(range(16384), dtype=np.uint64), + ), + 96, + id="add_lut_crt", + ) ] end_to_end_parallel_fixture = [