Skip to content

Commit

Permalink
Add modified Multi-snp and Docstring (#547)
Browse files Browse the repository at this point in the history
Adds a modified multivariate version of POE for multiple SNPs

Also fixes POESingleSNP's init docstring to match parameters

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced the `POEMultipleSNP2` class with methods for improved
prediction and classification.
  
- **Tests**
- Added new test function `test_multi2_fit` to validate the
functionality of the `POEMultipleSNP2` class.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
IlhaH and coderabbitai[bot] committed Jul 18, 2024
1 parent 23c662d commit 8f1df25
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 8 deletions.
218 changes: 210 additions & 8 deletions python/python/bystro/parent_of_origin/parent_of_origin.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
"""
import numpy as np
import numpy.linalg as la
from typing import Tuple, Union, Optional
from typing import Tuple, Union, Optional, Dict
from numpy.typing import NDArray
from tqdm import trange

from numba import jit # type: ignore

from sklearn.utils import resample
from sklearn.utils import resample # type: ignore

from bystro.covariance.optimal_shrinkage import optimal_shrinkage
from bystro.covariance._covariance_np import (
Expand Down Expand Up @@ -150,8 +150,20 @@ def __init__(
compute_pvalue : bool, optional, default=False
Whether to compute p-values for the test.
n_permutations : int, optional, default=10000
The number of permutations to perform for significance testing.
compute_ci : bool, optional, default=False
Whether to compute confidence intervals.
store_samples : bool, optional, default=False
Whether to store bootstrap samples.
pval_method : str, optional, default="rmt4ds"
The method for p-value computation.
n_permutations_pval : int, optional, default=10000
The number of permutations for p-value calculation.
n_permutations_bootstrap : int, optional, default=10000
The number of permutations for bootstrap confidence intervals.
cov_regularization : str, optional, default="Empirical"
The method of covariance regularization to use. Must be one of:
Expand Down Expand Up @@ -209,17 +221,20 @@ def fit(
Parameters
----------
X : np.array-like, shape=(N, self.p)
The phenotype data
X : np.array-like, shape=(N, p)
The phenotype data.
y : np.array-like, shape=(N,)
The genotype data indicating the number of copies of the
minority allele
minority allele.
seed : int, optional, default=2021
Seed for the random number generator.
Returns
-------
self : POESingleSNP
The instance of the method
The instance of the method.
"""
self._test_inputs(X, y)
self.n_phenotypes = X.shape[1]
Expand Down Expand Up @@ -556,3 +571,190 @@ def _test_inputs(self, X: np.ndarray, Y: np.ndarray) -> None:
raise ValueError("y is numpy array")
if X.shape[0] != Y.shape[0]:
raise ValueError("X and Y have different sample sizes")


class POEMultipleSNP2(BasePOE):
    """Parent-of-origin effect estimation for many SNPs with a shared null.

    For every genotype column a ``POESingleSNP`` model is fitted to obtain a
    parent-effect vector.  Instead of running a permutation test per SNP, a
    null distribution of effect norms is simulated once for each of a small
    set of allele-frequency thresholds, and each SNP's p-value is read off
    the null distribution of the threshold that matches its frequency.

    Attributes
    ----------
    p_vals : np.ndarray, shape=(n_genotypes,)
        Per-SNP p-values; empty until ``fit`` has been called.
    parent_effects_ : np.ndarray, shape=(n_genotypes, n_phenotypes)
        Per-SNP parent-effect vectors; empty until ``fit`` has been called.
    """

    def __init__(
        self,
        pval_method: str = "rmt4ds",
        cov_regularization: str = "Empirical",
        svd_loss: Optional[str] = None,
        n_repeats: int = 4000,
    ) -> None:
        """
        Parameters
        ----------
        pval_method : str, optional, default="rmt4ds"
            Stored on the instance.  It is not read by any method of this
            class.

        cov_regularization : str, optional, default="Empirical"
            The covariance estimator used when whitening.  Must be one of
            "Empirical", "NonLinear", "LinearInverse", "QuadraticInverse".

        svd_loss : str, optional, default=None
            If truthy, the singular values are passed through
            ``optimal_shrinkage`` with this loss.

        n_repeats : int, optional, default=4000
            The number of random splits simulated per frequency threshold
            when building the null distribution.

        Raises
        ------
        ValueError
            If `cov_regularization` is not one of the allowable values.
        """
        self.pval_method = pval_method
        self.cov_regularization = cov_regularization
        if cov_regularization == "Empirical":
            self.cov_reg: Union[
                EmpiricalCovariance,
                NonLinearShrinkageCovariance,
                LinearInverseShrinkage,
                QuadraticInverseShrinkage,
            ] = EmpiricalCovariance()
        elif cov_regularization == "NonLinear":
            self.cov_reg = NonLinearShrinkageCovariance()
        elif cov_regularization == "LinearInverse":
            self.cov_reg = LinearInverseShrinkage()
        elif cov_regularization == "QuadraticInverse":
            self.cov_reg = QuadraticInverseShrinkage()
        else:
            raise ValueError(
                "Invalid covariance regulator. Must be one of: Empirical, "
                "NonLinear, LinearInverse, QuadraticInverse"
            )
        self.svd_loss = svd_loss
        self.n_repeats = n_repeats
        self.p_vals: np.ndarray = np.array([])
        self.parent_effects_: np.ndarray = np.array([])

    def fit(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        seed: int = 2021,
    ) -> "POEMultipleSNP2":
        """
        Fit the POEMultipleSNP2 model.

        Parameters
        ----------
        X : np.ndarray, shape=(N, n_phenotypes)
            The phenotype data.

        Y : np.ndarray, shape=(N, n_genotypes)
            The genotype data indicating the number of copies of the
            minority allele.

        seed : int, optional, default=2021
            Seed for the random number generator.  It is also forwarded to
            every per-SNP ``POESingleSNP.fit`` call.

        Returns
        -------
        self : POEMultipleSNP2
            The instance of the method.

        Raises
        ------
        ValueError
            If the inputs fail validation, or if the sample size is too
            small to simulate a null distribution for one of the frequency
            thresholds.
        """
        self._test_inputs(X, Y)
        self.n_phenotypes = X.shape[1]
        self.n_genotypes = Y.shape[1]

        self.p_vals = -1 * np.ones(self.n_genotypes)
        self.parent_effects_ = np.zeros((self.n_genotypes, self.n_phenotypes))

        # Fraction of samples carrying at least one copy of the allele; this
        # is the frequency that gets compared against the thresholds below.
        maf_vals = np.mean(Y > 0, axis=0)
        maf_thresholds = [0.05, 0.01, 0.001]

        # A single generator is shared across thresholds, so the order in
        # which thresholds are simulated determines the random stream.
        rng = np.random.default_rng(seed)
        maf_perms: Dict[float, np.ndarray] = {
            maf: self._simulate_null_norms(X, maf, rng) for maf in maf_thresholds
        }
        smallest_threshold = min(maf_thresholds)

        for i in range(self.n_genotypes):
            current_maf = maf_vals[i]
            # Largest threshold not exceeding this SNP's frequency; SNPs
            # rarer than every threshold use the smallest one.
            appropriate_threshold = max(
                (t for t in maf_thresholds if t <= current_maf),
                default=smallest_threshold,
            )
            relevant_perms = maf_perms[appropriate_threshold]

            model = POESingleSNP(
                compute_pvalue=False,
                compute_ci=False,
                cov_regularization=self.cov_regularization,
                svd_loss=self.svd_loss,
            )
            model.fit(X, Y[:, i], seed=seed)
            self.parent_effects_[i] = model.parent_effect_
            norm_effect = np.linalg.norm(model.parent_effect_)

            # Share of simulated null norms at least as large as observed.
            self.p_vals[i] = (relevant_perms >= norm_effect).mean()

        return self

    def _simulate_null_norms(
        self, X: np.ndarray, maf: float, rng: np.random.Generator
    ) -> np.ndarray:
        """Simulate null parent-effect norms for one frequency threshold.

        Each repeat randomly splits the samples into a group of size
        ``int((1 - maf) ** 2 * N)`` and the remainder, then estimates a
        parent effect from the split and records its Euclidean norm.

        Parameters
        ----------
        X : np.ndarray, shape=(N, n_phenotypes)
            The phenotype data.

        maf : float
            The frequency threshold that determines the split sizes.

        rng : np.random.Generator
            Generator used to draw one permutation per repeat.

        Returns
        -------
        norms : np.ndarray, shape=(n_repeats,)
            The simulated effect norms.

        Raises
        ------
        ValueError
            If the second group would contain fewer than two samples, in
            which case its covariance cannot be estimated.
        """
        if self.n_repeats <= 0:
            return np.array([])

        # The split sizes depend only on the sample size and the threshold,
        # so they are computed once rather than on every repeat.
        n_total = X.shape[0]
        homo_count = int((1 - maf) ** 2 * n_total)
        het_count = n_total - homo_count
        if het_count < 2:
            raise ValueError(
                f"Sample size {n_total} is too small to simulate a null "
                f"distribution at frequency threshold {maf}: the "
                f"heterozygous group would contain {het_count} sample(s)"
            )

        cov_reg = self.cov_reg
        n_features = X.shape[1]
        norms = []
        for _ in range(self.n_repeats):
            perm_indices = rng.permutation(n_total)
            X_homo = X[perm_indices[:homo_count]]
            X_het = X[perm_indices[homo_count:]]

            X_homo = X_homo - np.mean(X_homo, axis=0)
            X_het = X_het - np.mean(X_het, axis=0)

            # Whiten the second group with the covariance of the first.
            cov_reg.fit(X_homo)
            Sigma_AA = np.array(cov_reg.covariance)
            L = la.cholesky(Sigma_AA)
            L_inv = la.inv(L)

            X_het_whitened = np.dot(X_het, L_inv.T)
            Sigma_AB_white = np.cov(X_het_whitened.T)

            _, s, Vt = la.svd(Sigma_AB_white)

            if self.svd_loss:
                s, _ = optimal_shrinkage(s, n_features / het_count, self.svd_loss)

            # Excess of the leading singular value over 1, floored at zero.
            norm_a = np.maximum(s[0] - 1, 0)
            parent_effect_white = Vt[0] * 2 * np.sqrt(norm_a)
            # Map the effect back from whitened to original coordinates.
            parent_effect = np.dot(parent_effect_white, L.T)
            norms.append(np.linalg.norm(parent_effect))
        return np.array(norms)

    def transform(
        self, X: np.ndarray, return_inner: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        This method predicts whether the heterozygote allele came from
        a maternal/paternal origin. Note that due to a lack of
        identifiability, we can't state whether class 1 is paternal or
        maternal.

        Parameters
        ----------
        X : np.ndarray, shape=(N, n_phenotypes)
            The phenotype data.

        return_inner : bool, default=False
            Whether to return the inner product classification, a measure
            of confidence in the call.

        Returns
        -------
        calls : np.ndarray, shape=(N, n_genotypes)
            A matrix of 1s and 0s predicting class.

        preds : np.ndarray, shape=(N, n_genotypes)
            The inner product, representing confidence in calls.  Only
            returned when `return_inner` is true.

        Raises
        ------
        AttributeError
            If the model has not been fitted.
        """
        if self.parent_effects_.ndim != 2:
            raise AttributeError(
                "POEMultipleSNP2 has not been fitted; call fit() before "
                "transform()"
            )
        X_dm = X - np.mean(X, axis=0)
        # One matrix product yields the inner product of every sample with
        # every SNP's parent-effect vector.
        preds = X_dm @ self.parent_effects_.T
        calls = (preds > 0).astype(float)
        if return_inner:
            return calls, preds
        return calls

    def _test_inputs(self, X: np.ndarray, Y: np.ndarray) -> None:
        """Validate that X and Y are 2-D numpy arrays with matching rows.

        Raises
        ------
        ValueError
            If either input is not a numpy array, is not 2-dimensional, or
            if the two inputs have a different number of rows.
        """
        if not isinstance(X, np.ndarray):
            raise ValueError("X must be a numpy array")
        if not isinstance(Y, np.ndarray):
            raise ValueError("Y must be a numpy array")
        if X.ndim != 2 or Y.ndim != 2:
            raise ValueError("X and Y must be 2-dimensional arrays")
        if X.shape[0] != Y.shape[0]:
            raise ValueError("X and Y have different sample sizes")
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from bystro.parent_of_origin.parent_of_origin import (
POESingleSNP,
POEMultipleSNP,
POEMultipleSNP2,
)


Expand Down Expand Up @@ -242,3 +243,19 @@ def test_multi_fit():
)
model = POEMultipleSNP()
model.fit(data["phenotypes"], data["genotypes"])


def test_multi2_fit():
    """Fit POEMultipleSNP2 on simulated data and check the fitted outputs."""
    rng = np.random.default_rng(2021)
    n_p = 40
    beta_m = np.zeros(n_p)
    beta_p = np.zeros(n_p)
    beta_p[:3] = 0.5
    data = generate_multivariate_data(
        beta_m, beta_p, rng, maf=0.03, n_individuals=50000, n_genotypes=1000
    )
    phenotypes = data["phenotypes"]
    genotypes = data["genotypes"]

    model = POEMultipleSNP2(n_repeats=10)
    fitted = model.fit(phenotypes, genotypes, seed=2021)

    # fit() returns the instance itself.
    assert fitted is model, "fit() should return the model instance"

    # One p-value and one effect vector per genotype column.
    n_snps = genotypes.shape[1]
    assert model.p_vals.shape == (n_snps,), "Unexpected p-value shape"
    assert model.parent_effects_.shape == (
        n_snps,
        phenotypes.shape[1],
    ), "Unexpected parent effect shape"

    # Each p-value is the mean of a boolean vector, so it lies in [0, 1];
    # this also confirms the -1 placeholders were all overwritten.
    assert np.all(
        (model.p_vals >= 0) & (model.p_vals <= 1)
    ), "p-values must lie in [0, 1]"

0 comments on commit 8f1df25

Please sign in to comment.