From 545bda979d87331869b5f37f74e6f9cbd3b5ba65 Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 21 Jul 2023 15:08:48 +0100 Subject: [PATCH] fix(compiler): use dyn sized tensors in CAPI func definitions --- .../Conversion/SimulateTFHE/SimulateTFHE.cpp | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp index 8d4ae390db..be7cc62d31 100644 --- a/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp +++ b/compilers/concrete-compiler/compiler/lib/Conversion/SimulateTFHE/SimulateTFHE.cpp @@ -54,6 +54,13 @@ class SimulateTFHETypeConverter : public mlir::TypeConverter { namespace { +mlir::RankedTensorType toDynamicTensorType(mlir::TensorType staticSizedTensor) { + std::vector dynSizedShape(staticSizedTensor.getShape().size(), + mlir::ShapedType::kDynamic); + return mlir::RankedTensorType::get(dynSizedShape, + staticSizedTensor.getElementType()); +} + struct NegOpPattern : public mlir::OpConversionPattern { NegOpPattern(mlir::MLIRContext *context, mlir::TypeConverter &typeConverter) @@ -131,6 +138,18 @@ struct EncodeExpandLutForBootstrapOpPattern eeOp.getResult().getType().cast(), mlir::ValueRange{}); + auto dynamicResultType = toDynamicTensorType(eeOp.getResult().getType()); + auto dynamicLutType = + toDynamicTensorType(eeOp.getInputLookupTable().getType()); + + mlir::Value castedOutputBuffer = rewriter.create( + eeOp.getLoc(), dynamicResultType, outputBuffer); + + mlir::Value castedLUT = rewriter.create( + eeOp.getLoc(), + toDynamicTensorType(eeOp.getInputLookupTable().getType()), + adaptor.getInputLookupTable()); + // sim_encode_expand_lut_for_boostrap(uint64_t *out_allocated, uint64_t // *out_aligned, uint64_t out_offset, uint64_t out_size, uint64_t // out_stride, uint64_t *in_allocated, uint64_t *in_aligned, uint64_t @@ -138,20 +157,18 @@ struct EncodeExpandLutForBootstrapOpPattern // uint32_t output_bits, bool is_signed) if (insertForwardDeclaration( eeOp, rewriter, funcName, - rewriter.getFunctionType({eeOp.getResult().getType(), - eeOp.getInputLookupTable().getType(), - rewriter.getIntegerType(32), - rewriter.getIntegerType(32), - rewriter.getIntegerType(1)}, - {})) + rewriter.getFunctionType( + {dynamicResultType, dynamicLutType, rewriter.getIntegerType(32), + rewriter.getIntegerType(32), rewriter.getIntegerType(1)}, + {})) .failed()) { return mlir::failure(); } rewriter.create( eeOp.getLoc(), funcName, mlir::TypeRange{}, - mlir::ValueRange({outputBuffer, adaptor.getInputLookupTable(), - polySizeCst, outputBitsCst, isSignedCst})); + mlir::ValueRange({castedOutputBuffer, castedLUT, polySizeCst, + outputBitsCst, isSignedCst})); rewriter.replaceOp(eeOp, outputBuffer); @@ -244,6 +261,11 @@ struct BootstrapGLWEOpPattern auto inputLweDimensionCst = rewriter.create( bsOp.getLoc(), inputLweDimension, 32); + auto dynamicLutType = toDynamicTensorType(bsOp.getLookupTable().getType()); + + mlir::Value castedLUT = rewriter.create( + bsOp.getLoc(), dynamicLutType, adaptor.getLookupTable()); + // uint64_t sim_bootstrap_lwe_u64(uint64_t plaintext, uint64_t // *tlu_allocated, uint64_t *tlu_aligned, uint64_t tlu_offset, uint64_t // tlu_size, uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t @@ -251,7 +273,7 @@ struct BootstrapGLWEOpPattern if (insertForwardDeclaration( bsOp, rewriter, funcName, rewriter.getFunctionType( - {rewriter.getIntegerType(64), bsOp.getLookupTable().getType(), + {rewriter.getIntegerType(64), dynamicLutType, rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32), rewriter.getIntegerType(32)}, @@ -262,7 +284,7 @@ struct BootstrapGLWEOpPattern rewriter.replaceOpWithNewOp( bsOp, funcName, this->getTypeConverter()->convertType(resultType), - mlir::ValueRange({adaptor.getCiphertext(), adaptor.getLookupTable(), + mlir::ValueRange({adaptor.getCiphertext(), castedLUT, inputLweDimensionCst, polySizeCst, levelsCst, baseLogCst, glweDimensionCst})); @@ -376,7 +398,8 @@ void SimulateTFHEPass::runOnOperation() { SimulateTFHETypeConverter converter; target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); // Make sure that no ops from `TFHE` remain after the lowering target.addIllegalDialect();