Skip to content

Commit

Permalink
feat(frontend-python): multi-parameters, Configuration, by-precision-…
Browse files Browse the repository at this point in the history
…and-norm2 strategy
  • Loading branch information
rudy-6-4 committed Jan 3, 2024
1 parent 3a992cf commit 6f78608
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.value("DAG_MULTI", optimizer::Strategy::DAG_MULTI)
.export_values();

pybind11::enum_<concrete_optimizer::MultiParamStrategy>(
m, "OptimizerMultiParameterStrategy")
.value("PRECISION", concrete_optimizer::MultiParamStrategy::ByPrecision)
.value("PRECISION_AND_NORM2",
concrete_optimizer::MultiParamStrategy::ByPrecisionAndNorm2)
.export_values();

pybind11::enum_<concrete_optimizer::Encoding>(m, "Encoding")
.value("AUTO", concrete_optimizer::Encoding::Auto)
.value("CRT", concrete_optimizer::Encoding::Crt)
Expand Down Expand Up @@ -107,6 +114,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](CompilationOptions &options, optimizer::Strategy strategy) {
options.optimizerConfig.strategy = strategy;
})
.def("set_optimizer_multi_parameter_strategy",
[](CompilationOptions &options,
concrete_optimizer::MultiParamStrategy strategy) {
options.optimizerConfig.multi_param_strategy = strategy;
})
.def("set_global_p_error",
[](CompilationOptions &options, double global_p_error) {
options.optimizerConfig.global_p_error = global_p_error;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mlir._mlir_libs._concretelang._compiler import (
CompilationOptions as _CompilationOptions,
OptimizerStrategy as _OptimizerStrategy,
OptimizerMultiParameterStrategy as _OptimizerMultiParameterStrategy,
Encoding,
)
from .wrapper import WrapperCpp
Expand Down Expand Up @@ -190,12 +191,27 @@ def set_optimizer_strategy(self, strategy: _OptimizerStrategy):
strategy (OptimizerStrategy): Use the specified optmizer strategy.
Raises:
TypeError: if the value is not a bool
TypeError: if the value is not an OptimizerStrategy
"""
if not isinstance(strategy, _OptimizerStrategy):
raise TypeError("enable should be a bool")
self.cpp().set_optimizer_strategy(strategy)

def set_optimizer_multi_parameter_strategy(
self, strategy: _OptimizerMultiParameterStrategy
):
"""Set the strategy of the optimizer for multi-parameter.
Args:
strategy (OptimizerMultiParameterStrategy): Use the specified optmizer multi-parameter strategy.
Raises:
TypeError: if the value is not a OptimizerMultiParameterStrategy
"""
if not isinstance(strategy, _OptimizerMultiParameterStrategy):
raise TypeError("enable should be a bool")
self.cpp().set_optimizer_multi_parameter_strategy(strategy)

def set_global_p_error(self, global_p_error: float):
"""Set global error probability for the full circuit.
Expand Down
4 changes: 4 additions & 0 deletions docs/howto/configure.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t
* Use single precision for the whole circuit.
* **parameter\_selection\_strategy**: (fhe.ParameterSelectionStrategy) = fhe.ParameterSelectionStrategy.MULTI
* Set how cryptographic parameters are selected.
* **multi\_parameter\_strategy**: fhe.MultiParameterStrategy = fhe.MultiParameterStrategy.PRECISION
* Set how cryptographic parameters are added and assigned in the circuit for the multi-parameters selection strategy.
* `PRECISION`: all TLU with same input precision have their own parameters.
* `PRECISION_AND_NORM2`: all TLU with same input precision and output [norm2](../../compilers/concrete-optimizer/v0-parameters/README.md) have their own parameters.
* **jit**: bool = False
* Enable JIT compilation.
* **loop\_parallelize**: bool = True
Expand Down
2 changes: 2 additions & 0 deletions docs/tutorial/multi_parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ When multi parameters are enabled, a different set of parameters are selected fo
- Larger memory usage during execution.

To disable it, you can use `parameter_selection_strategy=fhe.ParameterSelectionStrategy.MONO` configuration option.

When enabled, you can choose a strategy on how the multiple parameters are added and assigned in the circuit, see **multi\_parameter\_strategy** in [configuration](../howto/configure.md#options).
1 change: 1 addition & 0 deletions frontends/concrete-python/concrete/fhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
EncryptionStatus,
Keys,
MinMaxStrategy,
MultiParameterStrategy,
MultivariateStrategy,
ParameterSelectionStrategy,
Server,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ComparisonStrategy,
Configuration,
MinMaxStrategy,
MultiParameterStrategy,
MultivariateStrategy,
ParameterSelectionStrategy,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,32 @@ def parse(cls, string: str) -> "ParameterSelectionStrategy":
raise ValueError(message)


class MultiParameterStrategy(str, Enum):
"""
MultiParamStrategy, to set optimization strategy for multi-parameter.
"""

PRECISION = "precision"
PRECISION_AND_NORM2 = "precision_and_norm2"

@classmethod
def parse(cls, string: str) -> "MultiParameterStrategy":
"""Convert a string to a MultiParamStrategy."""
if isinstance(string, cls):
return string
if not isinstance(string, str):
message = f"{string} cannot be parsed to a {cls.__name__}"
raise TypeError(message)
for value in MultiParameterStrategy:
if string.lower().replace("-", "_") == value.value:
return value
message = (
f"'{string}' is not a valid '{friendly_type_format(cls)}' ("
f"{', '.join(v.value for v in MultiParameterStrategy)})"
)
raise ValueError(message)


class ComparisonStrategy(str, Enum):
"""
ComparisonStrategy, to specify implementation preference for comparisons.
Expand Down Expand Up @@ -887,6 +913,7 @@ class Configuration:
auto_adjust_truncators: bool
single_precision: bool
parameter_selection_strategy: ParameterSelectionStrategy
multi_parameter_strategy: MultiParameterStrategy
show_progress: bool
progress_title: str
progress_tag: Union[bool, int]
Expand Down Expand Up @@ -927,6 +954,9 @@ def __init__(
parameter_selection_strategy: Union[
ParameterSelectionStrategy, str
] = ParameterSelectionStrategy.MULTI,
multi_parameter_strategy: Union[
MultiParameterStrategy, str
] = MultiParameterStrategy.PRECISION,
show_progress: bool = False,
progress_title: str = "",
progress_tag: Union[bool, int] = False,
Expand Down Expand Up @@ -978,6 +1008,7 @@ def __init__(
self.parameter_selection_strategy = ParameterSelectionStrategy.parse(
parameter_selection_strategy
)
self.multi_parameter_strategy = MultiParameterStrategy.parse(multi_parameter_strategy)
self.show_progress = show_progress
self.progress_title = progress_title
self.progress_tag = progress_tag
Expand Down Expand Up @@ -1057,6 +1088,7 @@ def fork(
auto_adjust_truncators: Union[Keep, bool] = KEEP,
single_precision: Union[Keep, bool] = KEEP,
parameter_selection_strategy: Union[Keep, Union[ParameterSelectionStrategy, str]] = KEEP,
multi_parameter_strategy: Union[Keep, Union[MultiParameterStrategy, str]] = KEEP,
show_progress: Union[Keep, bool] = KEEP,
progress_title: Union[Keep, str] = KEEP,
progress_tag: Union[Keep, Union[bool, int]] = KEEP,
Expand Down
15 changes: 14 additions & 1 deletion frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,20 @@
set_compiler_logging,
set_llvm_debug_flag,
)
from mlir._mlir_libs._concretelang._compiler import KeyType, OptimizerStrategy, PrimitiveOperation
from mlir._mlir_libs._concretelang._compiler import (
KeyType,
OptimizerMultiParameterStrategy,
OptimizerStrategy,
PrimitiveOperation,
)
from mlir.ir import Module as MlirModule

from ..internal.utils import assert_that
from .configuration import (
DEFAULT_GLOBAL_P_ERROR,
DEFAULT_P_ERROR,
Configuration,
MultiParameterStrategy,
ParameterSelectionStrategy,
)
from .specs import ClientSpecs
Expand Down Expand Up @@ -159,6 +165,13 @@ def create(
options.set_optimizer_strategy(OptimizerStrategy.DAG_MONO)
elif parameter_selection_strategy == ParameterSelectionStrategy.MULTI: # pragma: no cover
options.set_optimizer_strategy(OptimizerStrategy.DAG_MULTI)

multi_parameter_strategy = configuration.multi_parameter_strategy
converter = {
MultiParameterStrategy.PRECISION: OptimizerMultiParameterStrategy.PRECISION,
MultiParameterStrategy.PRECISION_AND_NORM2: OptimizerMultiParameterStrategy.PRECISION_AND_NORM2, # noqa: E501
}
options.set_optimizer_multi_parameter_strategy(converter[multi_parameter_strategy])
try:
if configuration.compiler_debug_mode: # pragma: no cover
set_llvm_debug_flag(True)
Expand Down
10 changes: 10 additions & 0 deletions frontends/concrete-python/tests/compilation/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ def test_configuration_fork():
ValueError,
"'bad' is not a valid 'ParameterSelectionStrategy' (v0, mono, multi)",
),
pytest.param(
{"multi_parameter_strategy": 42},
TypeError,
"42 cannot be parsed to a MultiParameterStrategy",
),
pytest.param(
{"multi_parameter_strategy": "bad"},
ValueError,
"'bad' is not a valid 'MultiParameterStrategy' (precision, precision_and_norm2)",
),
pytest.param(
{"comparison_strategy_preference": 42},
TypeError,
Expand Down

1 comment on commit 6f78608

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 6f78608 Previous: f58c039 Ratio
v0 PBS table generation 59899817 ns/iter (± 1062102) 59913384 ns/iter (± 2099239) 1.00
v0 PBS simulate dag table generation 40644637 ns/iter (± 241557) 40566199 ns/iter (± 320862) 1.00
v0 WoP-PBS table generation 102927000 ns/iter (± 2392461) 105238697 ns/iter (± 410082) 0.98

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.