From 2356c3f9c66374ae3d10f0061b10d9ee64d99e46 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 23 May 2024 16:41:27 +0100 Subject: [PATCH] feat(compiler/frontend): add flag to enable/disable overflow detection in simulation --- .../concretelang/Support/CompilerEngine.h | 5 ++++- .../include/concretelang/Support/Pipeline.h | 1 + .../lib/Bindings/Python/CompilerAPIModule.cpp | 5 +++++ .../concrete/compiler/compilation_options.py | 17 +++++++++++++++++ .../compiler/lib/Support/CompilerEngine.cpp | 3 ++- .../compiler/lib/Support/Pipeline.cpp | 4 ++-- .../compiler/tests/python/test_simulation.py | 1 + .../concrete/fhe/compilation/configuration.py | 14 +++++++++++--- .../concrete/fhe/compilation/server.py | 10 +++++++++- 9 files changed, 52 insertions(+), 8 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h index 4100a43141..ae56ea9415 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilerEngine.h @@ -67,6 +67,9 @@ struct CompilationOptions { /// Simulate options bool simulate; + /// Enable overflow detection during simulation + bool enableOverflowDetectionInSimulation; + /// Parallelization options bool autoParallelize; bool loopParallelize; @@ -110,7 +113,7 @@ struct CompilationOptions { CompilationOptions() : v0FHEConstraints(std::nullopt), verifyDiagnostics(false), /// Simulate options - simulate(false), + simulate(false), enableOverflowDetectionInSimulation(false), // Parallelization options autoParallelize(false), loopParallelize(true), dataflowParallelize(false), diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h index 4953e69bb5..9c9e002396 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h @@ -109,6 +109,7 @@ transformTFHEOperations(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, std::optional &fheContext, + bool enableOverflowDetection, std::function enablePass); mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context, diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index f5d9734756..c53b2dca56 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -803,6 +803,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def("set_print_tlu_fusing", [](CompilationOptions &options, bool printTluFusing) { options.printTluFusing = printTluFusing; + }) + .def("set_enable_overflow_detection_in_simulation", + [](CompilationOptions &options, bool enableOverflowDetection) { + options.enableOverflowDetectionInSimulation = + enableOverflowDetection; }); pybind11::enum_(m, diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py index 3676b8b29d..7759718e70 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py @@ -481,3 +481,20 @@ def set_print_tlu_fusing(self, print_tlu_fusing: bool): if not isinstance(print_tlu_fusing, bool): raise TypeError("need to pass a boolean value") self.cpp().set_print_tlu_fusing(print_tlu_fusing) + + def set_enable_overflow_detection_in_simulation( + self, enable_overflow_detection: bool + ): + """Enable or disable overflow detection during simulation. + + Args: + enable_overflow_detection (bool): flag to enable or disable overflow detection + + Raises: + TypeError: if the value to set is not bool + """ + if not isinstance(enable_overflow_detection, bool): + raise TypeError("need to pass a boolean value") + self.cpp().set_enable_overflow_detection_in_simulation( + enable_overflow_detection + ) diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index 0eab6a7206..ee8f938958 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -520,7 +520,8 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, if (options.simulate) { if (mlir::concretelang::pipeline::simulateTFHE( - mlirContext, module, res.fheContext, this->enablePass) + mlirContext, module, res.fheContext, + options.enableOverflowDetectionInSimulation, this->enablePass) .failed()) { return StreamStringError("Simulating TFHE failed"); } diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 67093ee8f6..3a51aff910 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -439,13 +439,13 @@ transformTFHEOperations(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, std::optional &fheContext, + bool enableOverflowDetection, std::function enablePass) { mlir::PassManager pm(&context); // we want to disable overflow detection if CRT is used (overflow would be // expected) - bool enableOverflowDetection = true; - if (fheContext) { + if (fheContext && enableOverflowDetection) { auto solution = fheContext.value().solution; auto optCrt = getCrtDecompositionFromSolution(solution); if (optCrt) { diff --git a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py index 4755b7a4f9..6d36f8aa0c 100644 --- a/compilers/concrete-compiler/compiler/tests/python/test_simulation.py +++ b/compilers/concrete-compiler/compiler/tests/python/test_simulation.py @@ -57,6 +57,7 @@ def compile_run_assert( ): # compile with simulation options.simulation(True) + options.set_enable_overflow_detection_in_simulation(True) compilation_result = engine.compile(mlir_input, options) result = run_simulated(engine, args_and_shape, compilation_result) assert_result(result, expected_result) diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 3d89409eda..85bf865e2f 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -990,6 +990,7 @@ class Configuration: enable_tlu_fusing: bool print_tlu_fusing: bool optimize_tlu_based_on_original_bit_width: Union[bool, int] + detect_overflow_in_simulation: bool def __init__( self, @@ -1055,6 +1056,7 @@ def __init__( enable_tlu_fusing: bool = True, print_tlu_fusing: bool = False, optimize_tlu_based_on_original_bit_width: Union[bool, int] = 8, + detect_overflow_in_simulation: bool = False, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1155,6 +1157,8 @@ def __init__( self.optimize_tlu_based_on_original_bit_width = optimize_tlu_based_on_original_bit_width + self.detect_overflow_in_simulation = detect_overflow_in_simulation + self._validate() class Keep: @@ -1198,14 +1202,17 @@ def fork( compiler_debug_mode: Union[Keep, bool] = KEEP, compiler_verbose_mode: Union[Keep, bool] = KEEP, comparison_strategy_preference: Union[ - Keep, Optional[Union[ComparisonStrategy, str, List[Union[ComparisonStrategy, str]]]] + Keep, + Optional[Union[ComparisonStrategy, str, List[Union[ComparisonStrategy, str]]]], ] = KEEP, bitwise_strategy_preference: Union[ - Keep, Optional[Union[BitwiseStrategy, str, List[Union[BitwiseStrategy, str]]]] + Keep, + Optional[Union[BitwiseStrategy, str, List[Union[BitwiseStrategy, str]]]], ] = KEEP, shifts_with_promotion: Union[Keep, bool] = KEEP, multivariate_strategy_preference: Union[ - Keep, Optional[Union[MultivariateStrategy, str, List[Union[MultivariateStrategy, str]]]] + Keep, + Optional[Union[MultivariateStrategy, str, List[Union[MultivariateStrategy, str]]]], ] = KEEP, min_max_strategy_preference: Union[ Keep, Optional[Union[MinMaxStrategy, str, List[Union[MinMaxStrategy, str]]]] @@ -1223,6 +1230,7 @@ def fork( enable_tlu_fusing: Union[Keep, bool] = KEEP, print_tlu_fusing: Union[Keep, bool] = KEEP, optimize_tlu_based_on_original_bit_width: Union[Keep, bool, int] = KEEP, + detect_overflow_in_simulation: Union[Keep, bool] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index b940676169..6f31663f2a 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -124,6 +124,9 @@ def create( options.set_compress_evaluation_keys(configuration.compress_evaluation_keys) options.set_compress_input_ciphertexts(configuration.compress_input_ciphertexts) options.set_composable(configuration.composable) + options.set_enable_overflow_detection_in_simulation( + configuration.detect_overflow_in_simulation + ) if configuration.auto_parallelize or configuration.dataflow_parallelize: # pylint: disable=c-extension-no-member,no-member @@ -319,7 +322,12 @@ def load(path: Union[str, Path]) -> "Server": server_program = ServerProgram.load(support, is_simulated) return Server( - client_specs, output_dir, support, compilation_result, server_program, is_simulated + client_specs, + output_dir, + support, + compilation_result, + server_program, + is_simulated, ) def run(