Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Mudata support for MultiVI #3038

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d0ad5f5
Added mudata support for MULTIVI as well as tests
ori-kron-wis Nov 6, 2024
9450546
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
079faff
needed muon
ori-kron-wis Nov 6, 2024
5e8bc5f
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 6, 2024
b420037
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
815555c
Added ATC/PROTEIN + RNA capability for MultiVI + more tests like in t…
ori-kron-wis Nov 7, 2024
e03b006
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
ea59fd1
small fix
ori-kron-wis Nov 7, 2024
278353f
Merge remote-tracking branch 'origin/Ori-MultiVI-MuData' into Ori-Mul…
ori-kron-wis Nov 7, 2024
9118af8
small fix
ori-kron-wis Nov 7, 2024
870b0fc
small fix
ori-kron-wis Nov 7, 2024
4f08c1b
small fix
ori-kron-wis Nov 7, 2024
9622f88
small fix
ori-kron-wis Nov 7, 2024
987ebd3
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 13, 2024
371ef7a
fixed comments
ori-kron-wis Nov 13, 2024
3cf0108
fixed comments
ori-kron-wis Nov 13, 2024
ff82b28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2024
8326895
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 13, 2024
e670511
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 14, 2024
b6d2028
fixed typos
ori-kron-wis Nov 14, 2024
a903359
fix comments
ori-kron-wis Nov 18, 2024
c957c94
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 19, 2024
7b5f22f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2024
fe12fac
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 19, 2024
2a99751
added atac registry field
ori-kron-wis Nov 19, 2024
101641a
following can's fixes
ori-kron-wis Nov 20, 2024
e369eb1
Merge remote-tracking branch 'origin/main' into Ori-MultiVI-MuData
ori-kron-wis Nov 20, 2024
a0cd0bd
fix get_accessibility was using gene indices, should have used region…
ori-kron-wis Nov 20, 2024
711ec54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Starting from version 0.20.1, this format is based on [Keep a Changelog], and th
to [Semantic Versioning]. Full commit history is available in the
[commit logs](https://github.com/scverse/scvi-tools/commits/).

## Version 1.2
## Version 1.3

ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
### 1.3.0 (2024-XX-XX)

Expand All @@ -19,10 +19,15 @@ to [Semantic Versioning]. Full commit history is available in the
- Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable
representation learning in single-cell RNA sequencing data {pr}`3015`.

## Version 1.2

### 1.2.1 (2024-XX-XX)

#### Added

- Add MuData support for {class}`~scvi.model.MULTIVI` via the method
{meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`3038`.

#### Fixed

- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ regseq = ["biopython>=1.81", "genomepy"]
# read loom
loompy = ["loompy>=3.0.6"]
# scvi.criticism and read 10x
scanpy = ["scanpy>=1.6"]
scanpy = ["scanpy>=1.6","scikit-misc"]

optional = [
"scvi-tools[autotune,aws,hub,loompy,pymde,regseq,scanpy]"
Expand All @@ -107,7 +107,6 @@ tutorials = [
"pooch",
"pynndescent",
"igraph",
"scikit-misc",
"scrublet",
"scib-metrics",
"scvi-tools[optional]",
Expand Down
182 changes: 157 additions & 25 deletions src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import numpy as np
import pandas as pd
import torch
from mudata import MuData
from scipy.sparse import csr_matrix, vstack
from torch.distributions import Normal

from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data import AnnDataManager, fields
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
Expand Down Expand Up @@ -45,7 +46,7 @@

from anndata import AnnData

from scvi._types import Number
from scvi._types import AnnOrMuData, Number

logger = logging.getLogger(__name__)

Expand All @@ -59,7 +60,8 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):
Parameters
----------
adata
AnnData object that has been registered via :meth:`~scvi.model.MULTIVI.setup_anndata`.
AnnData/MuData object that has been registered via
:meth:`~scvi.model.MULTIVI.setup_anndata` or :meth:`~scvi.model.MULTIVI.setup_mudata`.
n_genes
The number of gene expression features (genes).
n_regions
Expand Down Expand Up @@ -116,13 +118,15 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):
--------
>>> adata_rna = anndata.read_h5ad(path_to_rna_anndata)
>>> adata_atac = scvi.data.read_10x_atac(path_to_atac_anndata)
>>> adata_multi = scvi.data.read_10x_multiome(path_to_multiomic_anndata)
>>> adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi, adata_rna, adata_atac)
>>> scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key="modality")
>>> vae = scvi.model.MULTIVI(adata_mvi)
>>> adata_protein = anndata.read_h5ad(path_to_protein_anndata)
>>> mdata = MuData({"rna": adata_rna, "protein": adata_protein, "atac": adata_atac})
>>> scvi.model.MULTIVI.setup_mudata(mdata, batch_key="batch",
...     modalities={"rna_layer": "rna", "protein_layer": "protein", "batch_key": "rna",
...                 "atac_layer": "atac"})
>>> vae = scvi.model.MULTIVI(mdata)
>>> vae.train()

Notes
Notes (for using setup_anndata)
-----
* The model assumes that the features are organized so that all expression features are
consecutive, followed by all accessibility features. For example, if the data has 100 genes
Expand All @@ -140,7 +144,7 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin):

def __init__(
self,
adata: AnnData,
adata: AnnOrMuData,
n_genes: int,
n_regions: int,
modality_weights: Literal["equal", "cell", "universal"] = "equal",
Expand Down Expand Up @@ -359,7 +363,7 @@ def train(
@torch.inference_mode()
def get_library_size_factors(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] = None,
batch_size: int = 128,
) -> dict[str, np.ndarray]:
Expand All @@ -368,8 +372,8 @@ def get_library_size_factors(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
AnnData or MuData object with equivalent structure to the initial object. If `None`,
defaults to the object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
batch_size
Expand Down Expand Up @@ -408,7 +412,7 @@ def get_region_factors(self) -> np.ndarray:
@torch.inference_mode()
def get_latent_representation(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
modality: Literal["joint", "expression", "accessibility"] = "joint",
indices: Sequence[int] | None = None,
give_mean: bool = True,
Expand All @@ -419,8 +423,8 @@ def get_latent_representation(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
AnnData or MuData object with equivalent structure to the initial object. If `None`,
defaults to the object used to initialize the model.
modality
Return modality specific or joint latent representation.
indices
Expand Down Expand Up @@ -478,7 +482,7 @@ def get_latent_representation(
@torch.inference_mode()
def get_accessibility_estimates(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] = None,
n_samples_overall: int | None = None,
region_list: Sequence[str] | None = None,
Expand All @@ -499,8 +503,8 @@ def get_accessibility_estimates(
Parameters
----------
adata
AnnData object that has been registered with scvi. If `None`, defaults to the
AnnData object used to initialize the model.
AnnData or MuData object that has been registered with scvi. If `None`, defaults to
the object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
n_samples_overall
Expand Down Expand Up @@ -588,13 +592,15 @@ def get_accessibility_estimates(
return pd.DataFrame(
imputed,
index=adata.obs_names[indices],
columns=adata.var_names[self.n_genes :][region_mask],
columns=adata["rna"].var_names[self.n_genes :][region_mask]
if isinstance(adata, MuData)
else adata.var_names[self.n_genes :][region_mask],
)

@torch.inference_mode()
def get_normalized_expression(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] | None = None,
n_samples_overall: int | None = None,
transform_batch: Sequence[Number | str] | None = None,
Expand All @@ -612,8 +618,8 @@ def get_normalized_expression(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
AnnData or MuData object with equivalent structure to the initial object. If `None`,
defaults to the object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
n_samples_overall
Expand Down Expand Up @@ -928,7 +934,7 @@ def differential_expression(
@torch.no_grad()
ori-kron-wis marked this conversation as resolved.
Show resolved Hide resolved
def get_protein_foreground_probability(
self,
adata: AnnData | None = None,
adata: AnnOrMuData | None = None,
indices: Sequence[int] | None = None,
transform_batch: Sequence[Number | str] | None = None,
protein_list: Sequence[str] | None = None,
Expand All @@ -945,8 +951,8 @@ def get_protein_foreground_probability(
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If ``None``, defaults to
the AnnData object used to initialize the model.
AnnData or MuData object with equivalent structure to the initial object. If ``None``,
defaults to the object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
transform_batch
Expand Down Expand Up @@ -1080,6 +1086,12 @@ def setup_anndata(
`adata.obsm[protein_expression_obsm_key]` if it is a DataFrame, else will assign
sequential names to proteins.
"""
warnings.warn(
"MULTIVI is supposed to work with MuData. the use of anndata is "
"deprecated and will be removed in scvi-tools 1.4. Please use setup_mudata",
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)
setup_method_args = cls._get_setup_method_args(**locals())
adata.obs["_indices"] = np.arange(adata.n_obs)
batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
Expand Down Expand Up @@ -1117,3 +1129,123 @@ def _check_adata_modality_weights(self, adata):
"""
if (adata is not None) and (self.module.modality_weights == "cell"):
raise RuntimeError("Held out data not permitted when using per cell weights")

@classmethod
@setup_anndata_dsp.dedent
def setup_mudata(
    cls,
    mdata: MuData,
    rna_layer: str | None = None,
    atac_layer: str | None = None,
    protein_layer: str | None = None,
    batch_key: str | None = None,
    size_factor_key: str | None = None,
    categorical_covariate_keys: list[str] | None = None,
    continuous_covariate_keys: list[str] | None = None,
    idx_layer: str | None = None,
    modalities: dict[str, str] | None = None,
    **kwargs,
):
    """%(summary_mdata)s.

    Parameters
    ----------
    %(param_mdata)s
    rna_layer
        RNA layer key. If `None`, will use `.X` of specified modality key.
    protein_layer
        Protein layer key. If `None`, will use `.X` of specified modality key.
    atac_layer
        ATAC layer key. If `None`, will use `.X` of specified modality key.
    %(param_batch_key)s
    %(param_size_factor_key)s
    %(param_cat_cov_keys)s
    %(param_cont_cov_keys)s
    %(idx_layer)s
    %(param_modalities)s

    Raises
    ------
    ValueError
        If `modalities` is `None`.

    Examples
    --------
    >>> mdata = muon.read_10x_h5("filtered_feature_bc_matrix.h5")
    >>> scvi.model.MULTIVI.setup_mudata(
    ...     mdata, modalities={"rna_layer": "rna", "atac_layer": "atac"}
    ... )
    >>> vae = scvi.model.MULTIVI(mdata)
    """
    # Must stay the first statement: ``locals()`` has to contain only the call arguments.
    setup_method_args = cls._get_setup_method_args(**locals())

    if modalities is None:
        raise ValueError("Modalities cannot be None.")
    modalities = cls._create_modalities_attr_dict(modalities, setup_method_args)

    # Per-cell running index, stored on the MuData-level ``obs`` table.
    # NOTE(review): the INDICES_KEY field below is looked up through
    # ``modalities.idx_layer``; confirm that this resolves to the table written here.
    mdata.obs["_indices"] = np.arange(mdata.n_obs)

    # Kept as a named object because the protein field reuses it for its batch mask.
    batch_field = fields.MuDataCategoricalObsField(
        REGISTRY_KEYS.BATCH_KEY,
        batch_key,
        mod_key=modalities.batch_key,
    )
    mudata_fields = [
        batch_field,
        # Labels are registered without a user-provided key or modality.
        fields.MuDataCategoricalObsField(
            REGISTRY_KEYS.LABELS_KEY,
            None,
            mod_key=None,
        ),
        fields.MuDataNumericalObsField(
            REGISTRY_KEYS.SIZE_FACTOR_KEY,
            size_factor_key,
            mod_key=modalities.size_factor_key,
            required=False,
        ),
        fields.MuDataCategoricalJointObsField(
            REGISTRY_KEYS.CAT_COVS_KEY,
            categorical_covariate_keys,
            mod_key=modalities.categorical_covariate_keys,
        ),
        fields.MuDataNumericalJointObsField(
            REGISTRY_KEYS.CONT_COVS_KEY,
            continuous_covariate_keys,
            mod_key=modalities.continuous_covariate_keys,
        ),
        fields.MuDataNumericalObsField(
            REGISTRY_KEYS.INDICES_KEY,
            "_indices",
            mod_key=modalities.idx_layer,
            required=False,
        ),
    ]

    # Each data modality is optional; a layer field is only added when the caller
    # mapped that modality in ``modalities``.
    if modalities.rna_layer is not None:
        mudata_fields.append(
            fields.MuDataLayerField(
                REGISTRY_KEYS.X_KEY,
                rna_layer,
                mod_key=modalities.rna_layer,
                is_count_data=True,
                mod_required=True,
            )
        )
    if modalities.atac_layer is not None:
        # NOTE(review): this reuses REGISTRY_KEYS.X_KEY, the same registry key as the
        # RNA field above. When both modalities are mapped, two fields share one key;
        # confirm whether ATAC should be registered under a dedicated key instead.
        mudata_fields.append(
            fields.MuDataLayerField(
                REGISTRY_KEYS.X_KEY,
                atac_layer,
                mod_key=modalities.atac_layer,
                is_count_data=True,
                mod_required=True,
            )
        )
    if modalities.protein_layer is not None:
        mudata_fields.append(
            fields.MuDataProteinLayerField(
                REGISTRY_KEYS.PROTEIN_EXP_KEY,
                protein_layer,
                mod_key=modalities.protein_layer,
                use_batch_mask=True,
                batch_field=batch_field,
                is_count_data=True,
                mod_required=True,
            )
        )

    adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args)
    adata_manager.register_fields(mdata, **kwargs)
    cls.register_manager(adata_manager)
14 changes: 10 additions & 4 deletions src/scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from anndata import AnnData
from mudata import MuData

from scvi._types import Number
from scvi._types import AnnOrMuData, Number

logger = logging.getLogger(__name__)

Expand All @@ -46,7 +46,8 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
Parameters
----------
adata
AnnData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`.
AnnData/MuData object that has been registered via
:meth:`~scvi.model.TOTALVI.setup_anndata` or :meth:`~scvi.model.TOTALVI.setup_mudata`.
n_latent
Dimensionality of the latent space.
gene_dispersion
Expand Down Expand Up @@ -108,7 +109,7 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):

def __init__(
self,
adata: AnnData,
adata: AnnOrMuData,
n_latent: int = 20,
gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein",
Expand Down Expand Up @@ -1214,6 +1215,11 @@ def setup_anndata(
-------
%(returns)s
"""
warnings.warn(
"TOTALVI is supposed to work with MuData.",
DeprecationWarning,
stacklevel=settings.warnings_stacklevel,
)
setup_method_args = cls._get_setup_method_args(**locals())
batch_field = fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
anndata_fields = [
Expand Down Expand Up @@ -1275,7 +1281,7 @@ def setup_mudata(
--------
>>> mdata = muon.read_10x_h5("pbmc_10k_protein_v3_filtered_feature_bc_matrix.h5")
>>> scvi.model.TOTALVI.setup_mudata(
mdata, modalities={"rna_layer": "rna": "protein_layer": "prot"}
mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"}
)
>>> vae = scvi.model.TOTALVI(mdata)
"""
Expand Down
Loading