Skip to content

Commit

Permalink
Report all issues rather than stop at first one. (#450)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaniyaki committed Jul 10, 2024
1 parent 3879491 commit 5b97ca3
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 12 deletions.
30 changes: 18 additions & 12 deletions src/estimagic/parameters/check_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,56 +39,62 @@ def check_constraints_are_satisfied(flat_constraints, param_values, param_names)
typ = constr["type"]
subset = param_values[constr["index"]]

report = []

_msg = partial(_get_message, constr, param_names)

if typ == "covariance":
cov = cov_params_to_matrix(subset)
e, _ = np.linalg.eigh(cov)
if not np.all(e > -1e-8):
raise InvalidParamsError(_msg())
report.append(_msg())
elif typ == "sdcorr":
cov = sdcorr_params_to_matrix(subset)
e, _ = np.linalg.eigh(cov)
if not np.all(e > -1e-8):
raise InvalidParamsError(_msg())
report.append(_msg())
elif typ == "probability":
if not np.isclose(subset.sum(), 1, rtol=0.01):
explanation = "Probabilities do not sum to 1."
raise InvalidParamsError(_msg(explanation))
report.append(_msg(explanation))
if np.any(subset < 0):
explanation = "There are negative Probabilities."
raise InvalidParamsError(_msg(explanation))
report.append(_msg(explanation))
if np.any(subset > 1):
explanation = "There are probabilities larger than 1."
raise InvalidParamsError(_msg(explanation))
report.append(_msg(explanation))
elif typ == "fixed":
if "value" in constr and not np.allclose(subset, constr["value"]):
explanation = (
"Fixing parameters to different values than their start values "
"was allowed in earlier versions of estimagic but is "
"forbidden now. "
)
raise InvalidParamsError(_msg(explanation))
report.append(_msg(explanation))
elif typ == "increasing":
if np.any(np.diff(subset) < 0):
raise InvalidParamsError(_msg())
report.append(_msg())
elif typ == "decreasing":
if np.any(np.diff(subset) > 0):
InvalidParamsError(_msg())
report.append(_msg())
elif typ == "linear":
wsum = subset.dot(constr["weights"])
if "lower_bound" in constr and wsum < constr["lower_bound"]:
explanation = "Lower bound of linear constraint is violated."
raise InvalidParamsError(_msg(explanation))
report.append(_msg(explanation))
elif "upper_bound" in constr and wsum > constr["upper_bound"]:
explanation = "Upper bound of linear constraint violated"
raise InvalidParamsError(_msg(explanation))
report.append(_msg(explanation))
elif "value" in constr and not np.isclose(wsum, constr["value"]):
explanation = "Equality condition of linear constraint violated"
raise InvalidParamsError(_msg(explanation))
report.append(_msg(explanation))
elif typ == "equality":
if len(set(subset.tolist())) > 1:
raise InvalidParamsError(_msg())
report.append(_msg())

report = "\n".join(report)
if report != "":
raise InvalidParamsError(f"Violated constraint at start params:\n{report}")


def _get_message(constraint, param_names, explanation=""):
Expand Down
110 changes: 110 additions & 0 deletions tests/parameters/test_check_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
import numpy as np
from estimagic.exceptions import InvalidParamsError
from estimagic.parameters.constraint_tools import check_constraints


def test_check_constraints_are_satisfied_type_equality():
with pytest.raises(InvalidParamsError):
check_constraints(
params=np.array([1, 2, 3]), constraints={"type": "equality", "loc": [0, 1]}
)


def test_check_constraints_are_satisfied_type_increasing():
with pytest.raises(InvalidParamsError):
check_constraints(
params=np.array([1, 2, 3, 2, 4]),
constraints={"type": "increasing", "loc": [1, 2, 3]},
)


def test_check_constraints_are_satisfied_type_decreasing():
with pytest.raises(InvalidParamsError):
check_constraints(
params=np.array([1, 2, 3, 2, 4]),
constraints={"type": "decreasing", "loc": [0, 1, 3]},
)


def test_check_constraints_are_satisfied_type_pairwise_equality():
with pytest.raises(InvalidParamsError):
check_constraints(
params=np.array([1, 2, 3, 3, 4]),
constraints={"type": "pairwise_equality", "locs": [[0, 4], [3, 2]]},
)


def test_check_constraints_are_satisfied_type_probability():
with pytest.raises(InvalidParamsError):
check_constraints(
params=np.array([0.10, 0.25, 0.50, 1, 0.7]),
constraints={"type": "probability", "loc": [0, 1, 2, 4]},
)


def test_check_constraints_are_satisfied_type_linear_lower_bound():
with pytest.raises(InvalidParamsError):
check_constraints(
params=np.ones(5),
constraints={
"type": "linear",
"loc": [0, 2, 3, 4],
"lower_bound": 1.1,
"weights": 0.25,
},
)


def test_check_constraints_are_satisfied_type_linear_upper_bound():
with pytest.raises(InvalidParamsError):
check_constraints(
params=np.ones(5),
constraints={
"type": "linear",
"loc": [0, 2, 3, 4],
"upper_bound": 0.9,
"weights": 0.25,
},
)


def test_check_constraints_are_satisfied_type_linear_value():
with pytest.raises(InvalidParamsError):
check_constraints(
params=np.ones(5),
constraints={
"type": "linear",
"loc": [0, 2, 3, 4],
"value": 2,
"weights": 0.25,
},
)


def test_check_constraints_are_satisfied_type_covariance():
with pytest.raises(InvalidParamsError):
check_constraints(
params=[1, 1, 1, -1, 1, -1],
constraints={
"type": "covariance",
# "loc": [0, 1, 2],
"selector": lambda params: params,
},
)


def test_check_constraints_are_satisfied_type_sdcorr():
with pytest.raises(InvalidParamsError):
check_constraints(
params=[1, 1, 1, -1, 1, 1],
constraints={
"type": "sdcorr",
# "loc": [0, 1, 2],
"selector": lambda params: params,
},
)


# to ignore as per email?
# def test_check_constraints_are_satisfied_type_nonlinear():

0 comments on commit 5b97ca3

Please sign in to comment.