From d96e5c20b866c25da5343d1cec76ab4a5299ac44 Mon Sep 17 00:00:00 2001 From: mdtanker Date: Tue, 19 Nov 2024 10:16:20 -0500 Subject: [PATCH] feat: allow `weight_by=None` in `merged_stats` --- src/invert4geom/uncertainty.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/invert4geom/uncertainty.py b/src/invert4geom/uncertainty.py index e0c82dce..900b6ebc 100644 --- a/src/invert4geom/uncertainty.py +++ b/src/invert4geom/uncertainty.py @@ -293,6 +293,7 @@ def regional_misfit_uncertainty( plot: bool = True, plot_region: tuple[float, float, float, float] | None = None, true_regional: xr.DataArray | None = None, + weight_by: str | None = None, **kwargs: typing.Any, ) -> xr.Dataset: """ @@ -368,9 +369,26 @@ def regional_misfit_uncertainty( # merge all topos into 1 dataset merged = merge_simulation_results(regional_grids) + # get constraint point RMSE of each model + if weight_by == "constraints": + weight_vals = [] + for g in regional_grids: + points = utils.sample_grids( + constraints_df, + g, + sampled_name="sampled_regional", + coord_names=["easting", "northing"], + ) + weight_vals.append(utils.rmse(points.sampled_regional)) + # convert residuals into weights + weights = [1 / (x**2) for x in weight_vals] + else: + weights = None + # get stats and weighted stats on the merged dataset stats_ds = model_ensemble_stats( merged, + weights=weights, ) if plot is True: @@ -886,6 +904,8 @@ def merged_stats( if weight_by == "residual": # get the RMS of the final gravity residual of each model weight_vals = [utils.rmse(df[list(df.columns)[-1]]) for df in grav_dfs] + # convert residuals into weights + weights = [1 / (x**2) for x in weight_vals] # get constraint point RMSE of each model elif weight_by == "constraints": weight_vals = [] @@ -900,10 +920,10 @@ def merged_stats( ) points["dif"] = points.upward - points.sampled_topo weight_vals.append(utils.rmse(points.dif)) - - # convert residuals into weights - weights = [1 / (x**2) for x in weight_vals] - + # convert residuals into weights + weights = [1 / (x**2) for x in weight_vals] + else: + weights = None # get stats and weighted stats on the merged dataset stats_ds = model_ensemble_stats( merged,