Skip to content

Commit

Permalink
feat: allow weight_by=None in merged_stats
Browse files Browse the repository at this point in the history
  • Loading branch information
mdtanker committed Nov 19, 2024
1 parent cc825e0 commit d96e5c2
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions src/invert4geom/uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def regional_misfit_uncertainty(
plot: bool = True,
plot_region: tuple[float, float, float, float] | None = None,
true_regional: xr.DataArray | None = None,
weight_by: str | None = None,
**kwargs: typing.Any,
) -> xr.Dataset:
"""
Expand Down Expand Up @@ -368,9 +369,26 @@ def regional_misfit_uncertainty(
# merge all topos into 1 dataset
merged = merge_simulation_results(regional_grids)

# get constraint point RMSE of each model
if weight_by == "constraints":
weight_vals = []
for g in regional_grids:
points = utils.sample_grids(
constraints_df,
g,
sampled_name="sampled_regional",
coord_names=["easting", "northing"],
)
weight_vals.append(utils.rmse(points.sampled_regional))
# convert residuals into weights
weights = [1 / (x**2) for x in weight_vals]
else:
weights = None

# get stats and weighted stats on the merged dataset
stats_ds = model_ensemble_stats(
merged,
weights=weights,
)

if plot is True:
Expand Down Expand Up @@ -886,6 +904,8 @@ def merged_stats(
if weight_by == "residual":
# get the RMS of the final gravity residual of each model
weight_vals = [utils.rmse(df[list(df.columns)[-1]]) for df in grav_dfs]
# convert residuals into weights
weights = [1 / (x**2) for x in weight_vals]
# get constraint point RMSE of each model
elif weight_by == "constraints":
weight_vals = []
Expand All @@ -900,10 +920,10 @@ def merged_stats(
)
points["dif"] = points.upward - points.sampled_topo
weight_vals.append(utils.rmse(points.dif))

# convert residuals into weights
weights = [1 / (x**2) for x in weight_vals]

# convert residuals into weights
weights = [1 / (x**2) for x in weight_vals]
else:
weights = None
# get stats and weighted stats on the merged dataset
stats_ds = model_ensemble_stats(
merged,
Expand Down

0 comments on commit d96e5c2

Please sign in to comment.