diff --git a/cmaes/_cmawm.py b/cmaes/_cmawm.py index ce4c189..93f5fcc 100644 --- a/cmaes/_cmawm.py +++ b/cmaes/_cmawm.py @@ -213,9 +213,7 @@ def ask(self) -> Tuple[np.ndarray, np.ndarray]: ) return x_encoded, x x = self._cma._sample_solution() - x[self._continuous_idx] = self._repair_continuous_params( - x[self._continuous_idx] - ) + x = self._cma._repair_infeasible_params(x) x_encoded = x.copy() x_encoded[self._discrete_idx] = self._encoding_discrete_params( x[self._discrete_idx] @@ -231,21 +229,6 @@ def _is_continuous_feasible(self, continuous_param: np.ndarray) -> bool: and np.all(continuous_param <= self._continuous_space[:, 1]), ) # Cast bool_ to bool. - def _repair_continuous_params(self, continuous_param: np.ndarray) -> np.ndarray: - if self._continuous_space is None: - return continuous_param - - # clip with lower and upper bound. - param = np.where( - continuous_param < self._continuous_space[:, 0], - self._continuous_space[:, 0], - continuous_param, - ) - param = np.where( - param > self._continuous_space[:, 1], self._continuous_space[:, 1], param - ) - return param - def _encoding_discrete_params(self, discrete_param: np.ndarray) -> np.ndarray: """Encode the values into discrete domain.""" mean = self._cma._mean