Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(compiler/frontend): add flag to enable/disable overflow detection in simulation #846

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ struct CompilationOptions {
/// Simulate options
bool simulate;

/// Enable overflow detection during simulation
bool enableOverflowDetectionInSimulation;

/// Parallelization options
bool autoParallelize;
bool loopParallelize;
Expand Down Expand Up @@ -110,7 +113,7 @@ struct CompilationOptions {
CompilationOptions()
: v0FHEConstraints(std::nullopt), verifyDiagnostics(false),
/// Simulate options
simulate(false),
simulate(false), enableOverflowDetectionInSimulation(false),
// Parallelization options
autoParallelize(false), loopParallelize(true),
dataflowParallelize(false),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ transformTFHEOperations(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
bool enableOverflowDetection,
std::function<bool(mlir::Pass *)> enablePass);

mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def("set_print_tlu_fusing",
[](CompilationOptions &options, bool printTluFusing) {
options.printTluFusing = printTluFusing;
})
.def("set_enable_overflow_detection_in_simulation",
[](CompilationOptions &options, bool enableOverflowDetection) {
options.enableOverflowDetectionInSimulation =
enableOverflowDetection;
});

pybind11::enum_<mlir::concretelang::PrimitiveOperation>(m,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,20 @@ def set_print_tlu_fusing(self, print_tlu_fusing: bool):
if not isinstance(print_tlu_fusing, bool):
raise TypeError("need to pass a boolean value")
self.cpp().set_print_tlu_fusing(print_tlu_fusing)

def set_enable_overflow_detection_in_simulation(
self, enable_overflow_detection: bool
):
"""Enable or disable overflow detection during simulation.
Args:
enable_overflow_detection (bool): flag to enable or disable overflow detection
Raises:
TypeError: if the value to set is not bool
"""
if not isinstance(enable_overflow_detection, bool):
raise TypeError("need to pass a boolean value")
self.cpp().set_enable_overflow_detection_in_simulation(
enable_overflow_detection
)
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,8 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target,

if (options.simulate) {
if (mlir::concretelang::pipeline::simulateTFHE(
mlirContext, module, res.fheContext, this->enablePass)
mlirContext, module, res.fheContext,
options.enableOverflowDetectionInSimulation, this->enablePass)
.failed()) {
return StreamStringError("Simulating TFHE failed");
}
Expand Down
4 changes: 2 additions & 2 deletions compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,13 +439,13 @@ transformTFHEOperations(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult simulateTFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module,
std::optional<V0FHEContext> &fheContext,
bool enableOverflowDetection,
std::function<bool(mlir::Pass *)> enablePass) {
mlir::PassManager pm(&context);

// we want to disable overflow detection if CRT is used (overflow would be
// expected)
bool enableOverflowDetection = true;
if (fheContext) {
if (fheContext && enableOverflowDetection) {
auto solution = fheContext.value().solution;
auto optCrt = getCrtDecompositionFromSolution(solution);
if (optCrt) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def compile_run_assert(
):
# compile with simulation
options.simulation(True)
options.set_enable_overflow_detection_in_simulation(True)
compilation_result = engine.compile(mlir_input, options)
result = run_simulated(engine, args_and_shape, compilation_result)
assert_result(result, expected_result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,7 @@ class Configuration:
enable_tlu_fusing: bool
print_tlu_fusing: bool
optimize_tlu_based_on_original_bit_width: Union[bool, int]
detect_overflow_in_simulation: bool

def __init__(
self,
Expand Down Expand Up @@ -1055,6 +1056,7 @@ def __init__(
enable_tlu_fusing: bool = True,
print_tlu_fusing: bool = False,
optimize_tlu_based_on_original_bit_width: Union[bool, int] = 8,
detect_overflow_in_simulation: bool = False,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
Expand Down Expand Up @@ -1155,6 +1157,8 @@ def __init__(

self.optimize_tlu_based_on_original_bit_width = optimize_tlu_based_on_original_bit_width

self.detect_overflow_in_simulation = detect_overflow_in_simulation

self._validate()

class Keep:
Expand Down Expand Up @@ -1198,14 +1202,17 @@ def fork(
compiler_debug_mode: Union[Keep, bool] = KEEP,
compiler_verbose_mode: Union[Keep, bool] = KEEP,
comparison_strategy_preference: Union[
Keep, Optional[Union[ComparisonStrategy, str, List[Union[ComparisonStrategy, str]]]]
Keep,
Optional[Union[ComparisonStrategy, str, List[Union[ComparisonStrategy, str]]]],
] = KEEP,
bitwise_strategy_preference: Union[
Keep, Optional[Union[BitwiseStrategy, str, List[Union[BitwiseStrategy, str]]]]
Keep,
Optional[Union[BitwiseStrategy, str, List[Union[BitwiseStrategy, str]]]],
] = KEEP,
shifts_with_promotion: Union[Keep, bool] = KEEP,
multivariate_strategy_preference: Union[
Keep, Optional[Union[MultivariateStrategy, str, List[Union[MultivariateStrategy, str]]]]
Keep,
Optional[Union[MultivariateStrategy, str, List[Union[MultivariateStrategy, str]]]],
] = KEEP,
min_max_strategy_preference: Union[
Keep, Optional[Union[MinMaxStrategy, str, List[Union[MinMaxStrategy, str]]]]
Expand All @@ -1223,6 +1230,7 @@ def fork(
enable_tlu_fusing: Union[Keep, bool] = KEEP,
print_tlu_fusing: Union[Keep, bool] = KEEP,
optimize_tlu_based_on_original_bit_width: Union[Keep, bool, int] = KEEP,
detect_overflow_in_simulation: Union[Keep, bool] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.
Expand Down
10 changes: 9 additions & 1 deletion frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def create(
options.set_compress_evaluation_keys(configuration.compress_evaluation_keys)
options.set_compress_input_ciphertexts(configuration.compress_input_ciphertexts)
options.set_composable(configuration.composable)
options.set_enable_overflow_detection_in_simulation(
configuration.detect_overflow_in_simulation
)

if configuration.auto_parallelize or configuration.dataflow_parallelize:
# pylint: disable=c-extension-no-member,no-member
Expand Down Expand Up @@ -319,7 +322,12 @@ def load(path: Union[str, Path]) -> "Server":
server_program = ServerProgram.load(support, is_simulated)

return Server(
client_specs, output_dir, support, compilation_result, server_program, is_simulated
client_specs,
output_dir,
support,
compilation_result,
server_program,
is_simulated,
)

def run(
Expand Down
Loading