Skip to content

Commit

Permalink
v0.0.11
Browse files Browse the repository at this point in the history
  • Loading branch information
fraterenz committed Jul 30, 2024
1 parent b95149a commit 0b831dd
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "futils"
version = "0.0.10"
version = "0.0.11"
authors = [
{ name="Francesco Terenzi", email="fra.terenz1993@gmail.com" },
]
Expand Down
30 changes: 23 additions & 7 deletions src/futils/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from . import snapshot
from abc import ABC, abstractmethod
from scipy import stats
from typing import List, NewType, Tuple
from typing import List, NewType, Set, Tuple


PosteriorIdx = NewType("PosteriorIdx", Set[int])
Posterior = NewType("Posterior", pd.Series)


Expand All @@ -18,15 +19,32 @@ def __subclasshook__(cls, subclass):
and callable(subclass.distance)
or NotImplemented
)

@abstractmethod
def distance(self, target: snapshot.Histogram, simulation: snapshot.Histogram) -> float:
def distance(
self, target: snapshot.Histogram, simulation: snapshot.Histogram
) -> float:
raise NotImplementedError



Stats = NewType("Stats", List[Stat])


def filter_runs_stat(
summary: pd.DataFrame, quantile: float, stat: Stat
) -> PosteriorIdx:
stat_name = stat.__class__.__name__
assert stat_name in set(
summary.columns
), f"metric {stat_name} not found in df with cols {set(summary.columns)}"
idx = summary.loc[
summary[stat_name] <= summary[stat_name].quantile(quantile), "idx"
]
idx_set = set(idx.idx.unique())
assert idx.shape[0] == len(idx_set)
return PosteriorIdx(idx_set)


@Stat.register
class Wasserstein:
def __init__(self) -> None:
Expand All @@ -47,9 +65,7 @@ def distance(self, target: snapshot.Histogram, sim: snapshot.Histogram) -> float
u_values, u_weights = list(target_uniformised.keys()), list(
target_uniformised.values()
)
return stats.wasserstein_distance(
u_values, v_values, u_weights, v_weights
)
return stats.wasserstein_distance(u_values, v_values, u_weights, v_weights)


def round_estimates(estimate: float, significant: str) -> str:
Expand Down Expand Up @@ -120,7 +136,7 @@ def plot_posterior(
color,
fancy: bool,
legend: bool = False,
xlim = None
xlim=None,
):
# https://matplotlib.org/stable/gallery/lines_bars_and_markers/stairs_demo.html
values = bins.compute_hist(posterior)
Expand Down

0 comments on commit 0b831dd

Please sign in to comment.