Skip to content

Commit

Permalink
fix: raise warnings and use fallback for nan scores in eq source fitting
Browse files (browse the repository at this point in the history)
  • Loading branch information
mdtanker committed Nov 19, 2024
1 parent 7495204 commit ad14dc0
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 31 deletions.
35 changes: 20 additions & 15 deletions src/invert4geom/cross_validation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1095,32 +1095,37 @@ def eq_sources_score(
score = np.nan
n_splits = 5
while np.isnan(score):
score = np.mean(
vd.cross_val_score(
eqs,
coordinates,
data,
delayed=delayed,
weights=weights,
cv=sklearn.model_selection.KFold(
n_splits=n_splits,
shuffle=True,
random_state=0,
),
try:
score = np.mean(
vd.cross_val_score(
eqs,
coordinates,
data,
delayed=delayed,
weights=weights,
cv=sklearn.model_selection.KFold(
n_splits=n_splits,
shuffle=True,
random_state=0,
),
)
)
)
except ValueError:
score = np.nan
if (n_splits == 5) and (np.isnan(score)):
msg = (
"eq sources score is NaN, reduce n_splits (5) by 1 until "
"eq sources score is NaN, reducing n_splits (5) by 1 until "
"scoring metric is defined"
)
log.warning(msg)

n_splits -= 1
if n_splits == 0:
break

if np.isnan(score):
msg = (
"score is still NaN after reduce n_splits, makes sure you're supplying "
"score is still NaN after reducing n_splits, makes sure you're supplying "
"enough points for the equivalent sources"
)
raise ValueError(msg)
Expand Down
19 changes: 13 additions & 6 deletions src/invert4geom/optimization.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1902,12 +1902,19 @@ def __call__(self, trial: optuna.trial) -> float:
log=True,
)

return cross_validation.eq_sources_score(
damping=damping,
depth=depth,
block_size=block_size,
**kwargs,
)
try:
score = cross_validation.eq_sources_score(
damping=damping,
depth=depth,
block_size=block_size,
**kwargs,
)
except ValueError as e:
log.error(e)
msg = "score could not be calculated, returning NaN"
log.warning(msg)
score = np.nan
return score


def optimize_eq_source_params(
Expand Down
40 changes: 30 additions & 10 deletions src/invert4geom/regional.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -548,16 +548,36 @@ def regional_constraints(

if cv is True:
# eqs = utils.best_equivalent_source_damping(
_, eqs = optimization.optimize_eq_source_params(
coordinates=coords,
data=constraints_df.sampled_grav,
# kwargs
weights=weights,
depth=depth,
damping=damping,
block_size=block_size,
**cv_kwargs, # type: ignore[arg-type]
)
try:
_, eqs = optimization.optimize_eq_source_params(
coordinates=coords,
data=constraints_df.sampled_grav,
# kwargs
weights=weights,
depth=depth,
damping=damping,
block_size=block_size,
**cv_kwargs, # type: ignore[arg-type]
)
except ValueError as e:
log.error(e)
msg = (
"eq sources optimization failed, using damping=None and "
"depth='default'"
)
log.error(msg)
eqs = hm.EquivalentSources(
depth="default",
damping=None,
block_size=block_size,
points=points,
)
eqs.fit(
coords,
constraints_df.sampled_grav,
weights=weights,
)

else:
# create set of deep sources
eqs = hm.EquivalentSources(
Expand Down

0 comments on commit ad14dc0

Please sign in to comment.