diff --git a/sfacts/evaluation.py b/sfacts/evaluation.py index 564d292..e8a9073 100644 --- a/sfacts/evaluation.py +++ b/sfacts/evaluation.py @@ -1,5 +1,5 @@ from scipy.spatial.distance import squareform -from scipy.stats import hmean +from scipy.stats import hmean, binom from sfacts.math import ( genotype_cdist, entropy, @@ -8,7 +8,7 @@ neighbor_joining, unifrac_pdist, ) -from sfacts.data import Genotype +from sfacts.data import Genotype, Metagenotype import pandas as pd import numpy as np @@ -181,6 +181,40 @@ def metagenotype_error2(world, metagenotype=None, discretized=False): return float(err.sum() / m.sum()), mean_sample_error.to_series() +def metagenotype_entropy_error( + world, metagenotype=None, discretized=False, fuzz_eps=1e-5, montecarlo_draws=1 +): + if metagenotype is None: + metagenotype = world + metagenotype = ( + metagenotype.metagenotype + ) # In case metagenotype is a full World object. + if discretized: + g = world.genotype.discretized().fuzzed(fuzz_eps).data + else: + g = world.genotype.data + p = world.community.data @ g + m = metagenotype.total_counts().astype(int) + mu = m.mean("position") + + obs_entropy = metagenotype.entropy() + err_accum = 0 + for i in range(montecarlo_draws): + sim = binom(m, p).rvs() + sim_mgtp = Metagenotype.from_counts_and_totals( + sim, + m, + coords=dict(sample=metagenotype.sample, position=metagenotype.position), + ) + sim_sample_entropy = sim_mgtp.entropy() + err = obs_entropy - sim_sample_entropy + err_accum += err + + err = err_accum / montecarlo_draws + + return ((err * mu).sum() / mu.sum()).values, err.to_series() + + def rank_abundance_error(reference, estimate, p=1): reference_num_strains = len(reference.strain) estimate_num_strains = len(estimate.strain)