diff --git a/CHANGELOG.md b/CHANGELOG.md index 228cad01e2..5942df1005 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,9 +28,11 @@ to [Semantic Versioning]. Full commit history is available in the validation set, if available. {pr}`3036`. - Add `batch_key` and `labels_key` to `scvi.external.SCAR.setup_anndata`. - Implemented variance of ZINB distribution. {pr}`3044`. +- MuData support for {class}`~scvi.model.MULTIVI` via the method + {meth}`~scvi.model.MULTIVI.setup_mudata` {pr}`3038`. - Add {class}`scvi.external.METHYLVI` for modeling methylation data from single-cell bisulfite sequencing (scBS-seq) experiments {pr}`2834`. - + #### Fixed - Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI` diff --git a/pyproject.toml b/pyproject.toml index 34c9d6b58b..8478eb7d71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,7 @@ regseq = ["biopython>=1.81", "genomepy"] # read loom loompy = ["loompy>=3.0.6"] # scvi.criticism and read 10x -scanpy = ["scanpy>=1.6"] +scanpy = ["scanpy>=1.6","scikit-misc"] optional = [ "scvi-tools[autotune,aws,hub,loompy,pymde,regseq,scanpy]" @@ -107,7 +107,6 @@ tutorials = [ "pooch", "pynndescent", "igraph", - "scikit-misc", "scrublet", "scib-metrics", "scvi-tools[optional]", diff --git a/src/scvi/_constants.py b/src/scvi/_constants.py index f565dc9f4a..ec6e4e914d 100644 --- a/src/scvi/_constants.py +++ b/src/scvi/_constants.py @@ -3,6 +3,7 @@ class _REGISTRY_KEYS_NT(NamedTuple): X_KEY: str = "X" + ATAC_X_KEY: str = "atac" BATCH_KEY: str = "batch" SAMPLE_KEY: str = "sample" LABELS_KEY: str = "labels" diff --git a/src/scvi/data/fields/_arraylike_field.py b/src/scvi/data/fields/_arraylike_field.py index 4e02751d8c..b4dafc9aed 100644 --- a/src/scvi/data/fields/_arraylike_field.py +++ b/src/scvi/data/fields/_arraylike_field.py @@ -212,6 +212,8 @@ class BaseJointField(BaseArrayLikeField): Sequence of keys to combine to form the obsm or varm field. field_type Type of field. Can be either 'obsm' or 'varm'. + required + If True, the field must be present in the AnnData object """ def __init__( @@ -219,6 +221,7 @@ def __init__( registry_key: str, attr_keys: list[str] | None, field_type: Literal["obsm", "varm"] = None, + required: bool = True, ) -> None: super().__init__(registry_key) if field_type == "obsm": @@ -232,6 +235,7 @@ def __init__( self._attr_key = f"_scvi_{registry_key}" self._attr_keys = attr_keys if attr_keys is not None else [] self._is_empty = len(self.attr_keys) == 0 + self._required = required def validate_field(self, adata: AnnData) -> None: """Validate the field.""" @@ -267,6 +271,10 @@ def attr_key(self) -> str: def is_empty(self) -> bool: return self._is_empty + @property + def required(self) -> bool: + return self._required + class NumericalJointField(BaseJointField): """An AnnDataField for a collection of numerical .obs or .var fields in AnnData. @@ -282,6 +290,8 @@ class NumericalJointField(BaseJointField): Sequence of keys to combine to form the obsm or varm field. field_type Type of field. Can be either 'obsm' or 'varm'. + required + If True, the field must be present in the AnnData object """ COLUMNS_KEY = "columns" @@ -291,8 +301,9 @@ def __init__( registry_key: str, attr_keys: list[str] | None, field_type: Literal["obsm", "varm"] = None, + required: bool = True, ) -> None: - super().__init__(registry_key, attr_keys, field_type=field_type) + super().__init__(registry_key, attr_keys, field_type=field_type, required=required) self.count_stat_key = f"n_{self.registry_key}" diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 85d307a25c..d9c9871169 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -9,11 +9,12 @@ import numpy as np import pandas as pd import torch +from mudata import MuData from scipy.sparse import csr_matrix, vstack from torch.distributions import Normal from scvi import REGISTRY_KEYS, settings -from scvi.data import AnnDataManager +from scvi.data import AnnDataManager, fields from scvi.data.fields import ( CategoricalJointObsField, CategoricalObsField, @@ -45,7 +46,7 @@ from anndata import AnnData - from scvi._types import Number + from scvi._types import AnnOrMuData, Number logger = logging.getLogger(__name__) @@ -59,7 +60,8 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin): Parameters ---------- adata - AnnData object that has been registered via :meth:`~scvi.model.MULTIVI.setup_anndata`. + AnnData/MuData object that has been registered via + :meth:`~scvi.model.MULTIVI.setup_anndata` or :meth:`~scvi.model.MULTIVI.setup_mudata`. n_genes The number of gene expression features (genes). n_regions @@ -116,13 +118,15 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin): -------- >>> adata_rna = anndata.read_h5ad(path_to_rna_anndata) >>> adata_atac = scvi.data.read_10x_atac(path_to_atac_anndata) - >>> adata_multi = scvi.data.read_10x_multiome(path_to_multiomic_anndata) - >>> adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi, adata_rna, adata_atac) - >>> scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key="modality") - >>> vae = scvi.model.MULTIVI(adata_mvi) + >>> adata_protein = anndata.read_h5ad(path_to_protein_anndata) + >>> mdata = MuData({"rna": adata_rna, "protein": adata_protein, "atac": adata_atac}) + >>> scvi.model.MULTIVI.setup_mudata(mdata, batch_key="batch", + >>> modalities={"rna_layer": "rna", "protein_layer": "protein", "batch_key": "rna", + >>> "atac_layer": "atac"}) + >>> vae = scvi.model.MULTIVI(mdata) >>> vae.train() - Notes + Notes (for using setup_anndata) ----- * The model assumes that the features are organized so that all expression features are consecutive, followed by all accessibility features. For example, if the data has 100 genes @@ -140,9 +144,9 @@ class MULTIVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, ArchesMixin): def __init__( self, - adata: AnnData, - n_genes: int, - n_regions: int, + adata: AnnOrMuData, + n_genes: int | None = None, + n_regions: int | None = None, modality_weights: Literal["equal", "cell", "universal"] = "equal", modality_penalty: Literal["Jeffreys", "MMD", "None"] = "Jeffreys", n_hidden: int | None = None, @@ -164,6 +168,13 @@ def __init__( ): super().__init__(adata) + if n_genes is None or n_regions is None: + assert isinstance( + adata, MuData + ), "n_genes and n_regions must be provided if using AnnData" + n_genes = self.summary_stats.get("n_vars", 0) + n_regions = self.summary_stats.get("n_atac", 0) + prior_mean, prior_scale = None, None n_cats_per_cov = ( self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key @@ -359,7 +370,7 @@ def train( @torch.inference_mode() def get_library_size_factors( self, - adata: AnnData | None = None, + adata: AnnOrMuData | None = None, indices: Sequence[int] = None, batch_size: int = 128, ) -> dict[str, np.ndarray]: @@ -368,8 +379,8 @@ def get_library_size_factors( Parameters ---------- adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. + AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults + to the AnnOrMuData object used to initialize the model. indices Indices of cells in adata to use. If `None`, all cells are used. batch_size @@ -408,7 +419,7 @@ def get_region_factors(self) -> np.ndarray: @torch.inference_mode() def get_latent_representation( self, - adata: AnnData | None = None, + adata: AnnOrMuData | None = None, modality: Literal["joint", "expression", "accessibility"] = "joint", indices: Sequence[int] | None = None, give_mean: bool = True, @@ -419,8 +430,8 @@ def get_latent_representation( Parameters ---------- adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. + AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults + to the AnnOrMuData object used to initialize the model. modality Return modality specific or joint latent representation. indices @@ -478,7 +489,7 @@ def get_latent_representation( @torch.inference_mode() def get_accessibility_estimates( self, - adata: AnnData | None = None, + adata: AnnOrMuData | None = None, indices: Sequence[int] = None, n_samples_overall: int | None = None, region_list: Sequence[str] | None = None, @@ -499,8 +510,8 @@ def get_accessibility_estimates( Parameters ---------- adata - AnnData object that has been registered with scvi. If `None`, defaults to the - AnnData object used to initialize the model. + AnnOrMuData object that has been registered with scvi. If `None`, defaults to the + AnnOrMuData object used to initialize the model. indices Indices of cells in adata to use. If `None`, all cells are used. n_samples_overall @@ -543,7 +554,7 @@ def get_accessibility_estimates( if region_list is None: region_mask = slice(None) else: - region_mask = [region in region_list for region in adata.var_names[self.n_genes :]] + region_mask = [region in region_list for region in adata.var_names[: self.n_regions]] if threshold is not None and (threshold < 0 or threshold > 1): raise ValueError("the provided threshold must be between 0 and 1") @@ -576,25 +587,36 @@ def get_accessibility_estimates( else: # imputed is a list of tensors imputed = torch.cat(imputed).numpy() - if return_numpy: - return imputed - elif threshold: - return pd.DataFrame.sparse.from_spmatrix( - imputed, - index=adata.obs_names[indices], - columns=adata.var_names[self.n_genes :][region_mask], - ) - else: + if np.all(imputed is None): return pd.DataFrame( imputed, index=adata.obs_names[indices], - columns=adata.var_names[self.n_genes :][region_mask], + columns=[], ) + else: + if return_numpy: + return imputed + elif threshold: + return pd.DataFrame.sparse.from_spmatrix( + imputed, + index=adata.obs_names[indices], + columns=adata["rna"].var_names[: self.n_regions][region_mask] + if isinstance(adata, MuData) + else adata.var_names[: self.n_regions][region_mask], + ) + else: + return pd.DataFrame( + imputed, + index=adata.obs_names[indices], + columns=adata["rna"].var_names[: self.n_regions][region_mask] + if isinstance(adata, MuData) + else adata.var_names[: self.n_regions][region_mask], + ) @torch.inference_mode() def get_normalized_expression( self, - adata: AnnData | None = None, + adata: AnnOrMuData | None = None, indices: Sequence[int] | None = None, n_samples_overall: int | None = None, transform_batch: Sequence[Number | str] | None = None, @@ -612,8 +634,8 @@ def get_normalized_expression( Parameters ---------- adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. + AnnOrMuData object with equivalent structure to initial AnnData. If `None`, defaults + to the AnnOrMuData object used to initialize the model. indices Indices of cells in adata to use. If `None`, all cells are used. n_samples_overall @@ -776,7 +798,7 @@ def differential_accessibility( """ self._check_adata_modality_weights(adata) adata = self._validate_anndata(adata) - col_names = adata.var_names[self.n_genes :] + col_names = adata.var_names[: self.n_genes] model_fn = partial( self.get_accessibility_estimates, use_z_mean=False, batch_size=batch_size ) @@ -797,7 +819,7 @@ def m1_domain_fn(samples): all_stats_fn = partial( scatac_raw_counts_properties, - var_idx=np.arange(adata.shape[1])[self.n_genes :], + var_idx=np.arange(adata.shape[1])[: self.n_genes], ) result = _de_core( @@ -928,7 +950,7 @@ def differential_expression( @torch.no_grad() def get_protein_foreground_probability( self, - adata: AnnData | None = None, + adata: AnnOrMuData | None = None, indices: Sequence[int] | None = None, transform_batch: Sequence[Number | str] | None = None, protein_list: Sequence[str] | None = None, @@ -945,8 +967,8 @@ def get_protein_foreground_probability( Parameters ---------- adata - AnnData object with equivalent structure to initial AnnData. If ``None``, defaults to - the AnnData object used to initialize the model. + AnnOrMuData object with equivalent structure to initial AnnData. If ``None``, defaults + to the AnnOrMuData object used to initialize the model. indices Indices of cells in adata to use. If `None`, all cells are used. transform_batch @@ -1080,6 +1102,12 @@ def setup_anndata( `adata.obsm[protein_expression_obsm_key]` if it is a DataFrame, else will assign sequential names to proteins. """ + warnings.warn( + "MULTIVI is supposed to work with MuData. the use of anndata is " + "deprecated and will be removed in scvi-tools 1.4. Please use setup_mudata", + DeprecationWarning, + stacklevel=settings.warnings_stacklevel, + ) setup_method_args = cls._get_setup_method_args(**locals()) adata.obs["_indices"] = np.arange(adata.n_obs) batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key) @@ -1088,7 +1116,7 @@ def setup_anndata( batch_field, CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, None), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), + NumericalJointObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False), CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), @@ -1117,3 +1145,126 @@ def _check_adata_modality_weights(self, adata): """ if (adata is not None) and (self.module.modality_weights == "cell"): raise RuntimeError("Held out data not permitted when using per cell weights") + + @classmethod + @setup_anndata_dsp.dedent + def setup_mudata( + cls, + mdata: MuData, + rna_layer: str | None = None, + atac_layer: str | None = None, + protein_layer: str | None = None, + batch_key: str | None = None, + size_factor_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, + continuous_covariate_keys: list[str] | None = None, + idx_layer: str | None = None, + modalities: dict[str, str] | None = None, + **kwargs, + ): + """%(summary_mdata)s. + + Parameters + ---------- + %(param_mdata)s + rna_layer + RNA layer key. If `None`, will use `.X` of specified modality key. + atac_layer + ATAC layer key. If `None`, will use `.X` of specified modality key. + protein_layer + Protein layer key. If `None`, will use `.X` of specified modality key. + %(param_batch_key)s + size_factor_key + Key in `mdata.obsm` for size factors. The first column corresponds to RNA size factors, + the second to ATAC size factors. + The second column need to be normalized and between 0 and 1. + %(param_cat_cov_keys)s + %(param_cont_cov_keys)s + %(idx_layer)s + %(param_modalities)s + + Examples + -------- + >>> mdata = muon.read_10x_h5("filtered_feature_bc_matrix.h5") + >>> scvi.model.MULTIVI.setup_mudata( + mdata, modalities={"rna_layer": "rna", "protein_layer": "atac"} + ) + >>> vae = scvi.model.MULTIVI(mdata) + """ + setup_method_args = cls._get_setup_method_args(**locals()) + + if modalities is None: + raise ValueError("Modalities cannot be None.") + modalities = cls._create_modalities_attr_dict(modalities, setup_method_args) + mdata.obs["_indices"] = np.arange(mdata.n_obs) + + batch_field = fields.MuDataCategoricalObsField( + REGISTRY_KEYS.BATCH_KEY, + batch_key, + mod_key=modalities.batch_key, + ) + mudata_fields = [ + batch_field, + fields.MuDataCategoricalObsField( + REGISTRY_KEYS.LABELS_KEY, + None, + mod_key=None, + ), + fields.MuDataNumericalJointObsField( + REGISTRY_KEYS.SIZE_FACTOR_KEY, + size_factor_key, + mod_key=None, + required=False, + ), + fields.MuDataCategoricalJointObsField( + REGISTRY_KEYS.CAT_COVS_KEY, + categorical_covariate_keys, + mod_key=modalities.categorical_covariate_keys, + ), + fields.MuDataNumericalJointObsField( + REGISTRY_KEYS.CONT_COVS_KEY, + continuous_covariate_keys, + mod_key=modalities.continuous_covariate_keys, + ), + fields.MuDataNumericalObsField( + REGISTRY_KEYS.INDICES_KEY, + "_indices", + mod_key=modalities.idx_layer, + required=False, + ), + ] + if modalities.rna_layer is not None: + mudata_fields.append( + fields.MuDataLayerField( + REGISTRY_KEYS.X_KEY, + rna_layer, + mod_key=modalities.rna_layer, + is_count_data=True, + mod_required=True, + ) + ) + if modalities.atac_layer is not None: + mudata_fields.append( + fields.MuDataLayerField( + REGISTRY_KEYS.ATAC_X_KEY, + atac_layer, + mod_key=modalities.atac_layer, + is_count_data=True, + mod_required=True, + ) + ) + if modalities.protein_layer is not None: + mudata_fields.append( + fields.MuDataProteinLayerField( + REGISTRY_KEYS.PROTEIN_EXP_KEY, + protein_layer, + mod_key=modalities.protein_layer, + use_batch_mask=True, + batch_field=batch_field, + is_count_data=True, + mod_required=True, + ) + ) + adata_manager = AnnDataManager(fields=mudata_fields, setup_method_args=setup_method_args) + adata_manager.register_fields(mdata, **kwargs) + cls.register_manager(adata_manager) diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index a50c56e3ee..6eb7cc1a7a 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -35,7 +35,7 @@ from anndata import AnnData from mudata import MuData - from scvi._types import Number + from scvi._types import AnnOrMuData, Number logger = logging.getLogger(__name__) @@ -46,7 +46,8 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass): Parameters ---------- adata - AnnData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`. + AnnData/MuData object that has been registered via + :meth:`~scvi.model.TOTALVI.setup_anndata` or :meth:`~scvi.model.TOTALVI.setup_mudata`. n_latent Dimensionality of the latent space. gene_dispersion @@ -108,7 +109,7 @@ class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass): def __init__( self, - adata: AnnData, + adata: AnnOrMuData, n_latent: int = 20, gene_dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", protein_dispersion: Literal["protein", "protein-batch", "protein-label"] = "protein", @@ -1214,6 +1215,11 @@ def setup_anndata( ------- %(returns)s """ + warnings.warn( + "TOTALVI is supposed to work with MuData.", + DeprecationWarning, + stacklevel=settings.warnings_stacklevel, + ) setup_method_args = cls._get_setup_method_args(**locals()) batch_field = fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key) anndata_fields = [ @@ -1275,7 +1281,7 @@ def setup_mudata( -------- >>> mdata = muon.read_10x_h5("pbmc_10k_protein_v3_filtered_feature_bc_matrix.h5") >>> scvi.model.TOTALVI.setup_mudata( - mdata, modalities={"rna_layer": "rna": "protein_layer": "prot"} + mdata, modalities={"rna_layer": "rna", "protein_layer": "prot"} ) >>> vae = scvi.model.TOTALVI(mdata) """ diff --git a/src/scvi/module/_multivae.py b/src/scvi/module/_multivae.py index 6ad24b65d2..e35c0d418d 100644 --- a/src/scvi/module/_multivae.py +++ b/src/scvi/module/_multivae.py @@ -259,8 +259,6 @@ class MULTIVAE(BaseModuleClass): RNA distribution. """ - # TODO: replace n_input_regions and n_input_genes with a gene/region mask (we don't dictate - # which comes first or that they're even contiguous) def __init__( self, n_input_regions: int = 0, @@ -301,7 +299,7 @@ def __init__( if n_input_regions == 0: self.n_hidden = np.min([128, int(np.sqrt(self.n_input_genes))]) else: - self.n_hidden = int(np.sqrt(self.n_input_regions)) + self.n_hidden = np.min([128, int(np.sqrt(self.n_input_regions))]) else: self.n_hidden = n_hidden self.n_batch = n_batch @@ -533,7 +531,12 @@ def __init__( def _get_inference_input(self, tensors): """Get input tensors for the inference model.""" - x = tensors[REGISTRY_KEYS.X_KEY] + x = tensors.get(REGISTRY_KEYS.X_KEY, None) + x_atac = tensors.get(REGISTRY_KEYS.ATAC_X_KEY, None) + if x is not None and x_atac is not None: + x = torch.cat((x, x_atac), dim=-1) + elif x is None: + x = x_atac if self.n_input_proteins == 0: y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) else: @@ -543,6 +546,7 @@ def _get_inference_input(self, tensors): cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY) cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY) label = tensors[REGISTRY_KEYS.LABELS_KEY] + size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None) input_dict = { "x": x, "y": y, @@ -551,6 +555,7 @@ def _get_inference_input(self, tensors): "cat_covs": cat_covs, "label": label, "cell_idx": cell_idx, + "size_factor": size_factor, } return input_dict @@ -564,6 +569,7 @@ def inference( cat_covs, label, cell_idx, + size_factor, n_samples=1, ) -> dict[str, torch.Tensor]: """Run the inference model.""" @@ -573,21 +579,21 @@ def inference( else: x_rna = x[:, : self.n_input_genes] if self.n_input_regions == 0: - x_chr = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) + x_atac = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) else: - x_chr = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] + x_atac = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] mask_expr = x_rna.sum(dim=1) > 0 - mask_acc = x_chr.sum(dim=1) > 0 + mask_acc = x_atac.sum(dim=1) > 0 mask_pro = y.sum(dim=1) > 0 if cont_covs is not None and self.encode_covariates: encoder_input_expression = torch.cat((x_rna, cont_covs), dim=-1) - encoder_input_accessibility = torch.cat((x_chr, cont_covs), dim=-1) + encoder_input_accessibility = torch.cat((x_atac, cont_covs), dim=-1) encoder_input_protein = torch.cat((y, cont_covs), dim=-1) else: encoder_input_expression = x_rna - encoder_input_accessibility = x_chr + encoder_input_accessibility = x_atac encoder_input_protein = y if cat_covs is not None and self.encode_covariates: @@ -607,12 +613,16 @@ def inference( ) # L encoders - libsize_expr = self.l_encoder_expression( - encoder_input_expression, batch_index, *categorical_input - ) - libsize_acc = self.l_encoder_accessibility( - encoder_input_accessibility, batch_index, *categorical_input - ) + if self.use_size_factor_key: + libsize_expr = torch.log(size_factor[:, [0]] + 1e-6) + libsize_acc = size_factor[:, [1]] + else: + libsize_expr = self.l_encoder_expression( + encoder_input_expression, batch_index, *categorical_input + ) + libsize_acc = self.l_encoder_accessibility( + encoder_input_accessibility, batch_index, *categorical_input + ) # mix representations if self.modality_weights == "cell": @@ -651,6 +661,7 @@ def unsqz(zt, n_s): z = self.z_encoder_accessibility.z_transformation(untran_z) outputs = { + "x": x, "z": z, "qz_m": qz_m, "qz_v": qz_v, @@ -674,11 +685,6 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None qz_m = inference_outputs["qz_m"] libsize_expr = inference_outputs["libsize_expr"] - size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY - size_factor = ( - torch.log(tensors[size_factor_key]) if size_factor_key in tensors.keys() else None - ) - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] cont_key = REGISTRY_KEYS.CONT_COVS_KEY cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None @@ -698,7 +704,6 @@ def _get_generative_input(self, tensors, inference_outputs, transform_batch=None "cont_covs": cont_covs, "cat_covs": cat_covs, "libsize_expr": libsize_expr, - "size_factor": size_factor, "label": label, } return input_dict @@ -712,7 +717,6 @@ def generative( cont_covs=None, cat_covs=None, libsize_expr=None, - size_factor=None, use_z_mean=False, label: torch.Tensor = None, ): @@ -736,12 +740,10 @@ def generative( p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input) # Expression Decoder - if not self.use_size_factor_key: - size_factor = libsize_expr px_scale, _, px_rate, px_dropout = self.z_decoder_expression( self.gene_dispersion, decoder_input, - size_factor, + libsize_expr, batch_index, *categorical_input, label, @@ -783,24 +785,23 @@ def generative( def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): """Computes the loss function for the model.""" # Get the data - x = tensors[REGISTRY_KEYS.X_KEY] + x = inference_outputs["x"] - # TODO: CHECK IF THIS FAILS IN ONLY RNA DATA x_rna = x[:, : self.n_input_genes] - x_chr = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] + x_atac = x[:, self.n_input_genes : (self.n_input_genes + self.n_input_regions)] if self.n_input_proteins == 0: y = torch.zeros(x.shape[0], 1, device=x.device, requires_grad=False) else: y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY] mask_expr = x_rna.sum(dim=1) > 0 - mask_acc = x_chr.sum(dim=1) > 0 + mask_acc = x_atac.sum(dim=1) > 0 mask_pro = y.sum(dim=1) > 0 # Compute Accessibility loss p = generative_outputs["p"] libsize_acc = inference_outputs["libsize_acc"] - rl_accessibility = self.get_reconstruction_loss_accessibility(x_chr, p, libsize_acc) + rl_accessibility = self.get_reconstruction_loss_accessibility(x_atac, p, libsize_acc) # Compute Expression loss px_rate = generative_outputs["px_rate"] @@ -819,7 +820,6 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float rl_protein = torch.zeros(x.shape[0], device=x.device, requires_grad=False) # calling without weights makes this act like a masked sum - # TODO : CHECK MIXING HERE recon_loss_expression = rl_expression * mask_expr recon_loss_accessibility = rl_accessibility * mask_acc recon_loss_protein = rl_protein * mask_pro diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index bdeebba550..210baec931 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -1,8 +1,16 @@ +import os + +import anndata as ad import numpy as np import pytest +import scanpy as sc +from mudata import MuData +import scvi +from scvi import REGISTRY_KEYS from scvi.data import synthetic_iid from scvi.model import MULTIVI +from scvi.utils import attrdict def test_multivi(): @@ -34,7 +42,7 @@ def test_multivi(): # Test with size factor data = synthetic_iid() data.obs["size_factor"] = np.random.randint(1, 5, size=(data.shape[0],)) - MULTIVI.setup_anndata(data, batch_key="batch", size_factor_key="size_factor") + MULTIVI.setup_anndata(data, batch_key="batch") vae = MULTIVI( data, n_genes=50, @@ -76,3 +84,415 @@ def test_multivi_single_batch(): vae = MULTIVI(data, n_genes=50, n_regions=50) with pytest.warns(UserWarning): vae.train(3) + + +def test_multivi_mudata_rna_prot_external(): + # Example on how to download protein adata to mudata (from multivi tutorial) - mudata RNA/PROT + adata = scvi.data.pbmcs_10x_cite_seq() + adata.layers["counts"] = adata.X.copy() + adata.obs_names_make_unique() + protein_adata = ad.AnnData(adata.obsm["protein_expression"]) + protein_adata.obs_names = adata.obs_names + del adata.obsm["protein_expression"] + mdata = MuData({"rna": adata, "protein": protein_adata}) + sc.pp.highly_variable_genes( + mdata.mod["rna"], + n_top_genes=4000, + flavor="seurat_v3", + batch_key="batch", + layer="counts", + ) + mdata.mod["rna_subset"] = mdata.mod["rna"][:, mdata.mod["rna"].var["highly_variable"]].copy() + mdata.update() + # mdata + # mdata.mod + MULTIVI.setup_mudata( + mdata, + rna_layer="counts", # mean we use: mdata.mod["rna_subset"].layers["counts"] + protein_layer=None, # mean we use: mdata.mod["protein"].X + batch_key="batch", # the batch is here: mdata.mod["rna_subset"].obs["batch"] + modalities={ + "rna_layer": "rna_subset", + "protein_layer": "protein", + "batch_key": "rna_subset", + }, + ) + model = MULTIVI(mdata) + model.train(1, train_size=0.9) + + +def test_multivi_mudata_rna_atac_external(): + # optional data - mudata RNA/ATAC + mdata = synthetic_iid(return_mudata=True) + sc.pp.highly_variable_genes( + mdata.mod["rna"], + n_top_genes=4000, + flavor="seurat_v3", + ) + mdata.mod["rna_subset"] = mdata.mod["rna"][:, mdata.mod["rna"].var["highly_variable"]].copy() + sc.pp.highly_variable_genes( + mdata.mod["accessibility"], + n_top_genes=4000, + flavor="seurat_v3", + ) + mdata.mod["atac_subset"] = mdata.mod["accessibility"][ + :, mdata.mod["accessibility"].var["highly_variable"] + ].copy() + mdata.update() + MULTIVI.setup_mudata( + mdata, + modalities={"rna_layer": "rna_subset", "atac_layer": "atac_subset"}, + ) + model = MULTIVI(mdata) + model.train(1, train_size=0.9) + + +def test_multivi_mudata_trimodal_external(): + # optional data - mudata RNA/ATAC + mdata = synthetic_iid(return_mudata=True) + MULTIVI.setup_mudata( + mdata, + modalities={ + "rna_layer": "rna", + "atac_layer": "accessibility", + "protein_layer": "protein_expression", + }, + ) + model = MULTIVI(mdata) + model.train(1, train_size=0.9) + model.train(1, train_size=0.9) + assert model.is_trained is True + model.get_latent_representation() + model.get_elbo() + model.get_reconstruction_error() + model.get_normalized_expression() + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + model.get_elbo(indices=model.validation_indices) + model.get_reconstruction_error(indices=model.validation_indices) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + +@pytest.mark.parametrize("n_genes", [25, 50, 100]) +@pytest.mark.parametrize("n_regions", [25, 50, 100]) +def test_multivi_mudata(n_genes: int, n_regions: int): + # use of syntetic data of rna/proteins/atac for speed + + mdata = synthetic_iid(return_mudata=True) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={ + "rna_layer": "rna", + "protein_layer": "protein_expression", + "atac_layer": "accessibility", + }, + ) + n_obs = mdata.n_obs + n_latent = 10 + + model = MULTIVI(mdata, n_latent=n_latent, n_genes=n_genes, n_regions=n_regions) + model.train(1, train_size=0.9) + assert model.is_trained is True + z = model.get_latent_representation() + assert z.shape == (n_obs, n_latent) + model.get_elbo() + model.get_reconstruction_error() + model.get_normalized_expression() + model.get_normalized_expression(transform_batch=["batch_0", "batch_1"]) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + model.get_elbo(indices=model.validation_indices) + model.get_reconstruction_error(indices=model.validation_indices) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + mdata2 = synthetic_iid(return_mudata=True) + MULTIVI.setup_mudata( + mdata2, + batch_key="batch", + modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, + ) + norm_exp = model.get_normalized_expression(mdata2, indices=[1, 2, 3]) + assert norm_exp.shape == (3, n_genes) + + # test transfer_anndata_setup + view + mdata3 = synthetic_iid(return_mudata=True) + mdata3.obs["_indices"] = np.arange(mdata3.n_obs) + model.get_elbo(mdata3[:10]) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + +def test_multivi_auto_transfer_mudata(): + # test automatic transfer_fields + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + model = MULTIVI(mdata) + adata2 = synthetic_iid() + protein_adata2 = synthetic_iid(n_genes=50) + mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) + mdata2.obs["_indices"] = np.arange(mdata2.n_obs) + model.get_elbo(mdata2) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + +def test_multivi_incorrect_mapping_mudata(): + # test that we catch incorrect mappings + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + model = MULTIVI(mdata) + adata2 = synthetic_iid() + protein_adata2 = synthetic_iid(n_genes=50) + mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) + adata2.obs.batch = adata2.obs.batch.cat.rename_categories(["batch_0", "batch_10"]) + with pytest.raises(ValueError): + model.get_elbo(mdata2) + + +def test_multivi_reordered_mapping_mudata(): + # test that same mapping different order is okay + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + model = MULTIVI(mdata) + adata2 = synthetic_iid() + protein_adata2 = synthetic_iid(n_genes=50) + mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) + adata2.obs.batch = adata2.obs.batch.cat.rename_categories(["batch_1", "batch_0"]) + mdata2.obs["_indices"] = np.arange(mdata2.n_obs) + model.get_elbo(mdata2) + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + +def test_multivi_model_library_size_mudata(): + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + + n_latent = 10 + model = MULTIVI(mdata, n_latent=n_latent) + model.train(1, train_size=0.5) + assert model.is_trained is True + model.get_elbo() + model.get_accessibility_estimates() + model.get_accessibility_estimates(normalize_cells=True) + model.get_accessibility_estimates(normalize_regions=True) + model.get_library_size_factors() + model.get_region_factors() + + +def test_multivi_size_factor_mudata(): + mdata = synthetic_iid(return_mudata=True) + mdata.obs["size_factor_rna"] = mdata["rna"].X.sum(1) + mdata.obs["size_factor_atac"] = (mdata["accessibility"].X.sum(1) + 1) / ( + np.max(mdata["accessibility"].X.sum(1)) + 1.01 + ) + MULTIVI.setup_mudata( + mdata, + modalities={"rna_layer": "rna", "atac_layer": "accessibility"}, + size_factor_key=["size_factor_rna", "size_factor_atac"], + ) + + n_latent = 10 + + # Test size_factor_key overrides use_observed_lib_size. + model = MULTIVI(mdata, n_latent=n_latent) + assert model.module.use_size_factor_key + model.train(1, train_size=0.5) + + model = MULTIVI(mdata, n_latent=n_latent) + assert model.module.use_size_factor_key + model.train(1, train_size=0.5) + + +def test_multivi_saving_and_loading_mudata(save_path: str): + adata = synthetic_iid() + protein_adata = synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + model = MULTIVI(mdata) + model.train(1, train_size=0.2) + z1 = model.get_latent_representation(mdata) + test_idx1 = model.validation_indices + + model.save(save_path, overwrite=True, save_anndata=True) + model.view_setup_args(save_path) + + model = MULTIVI.load(save_path) + model.get_latent_representation() + + # Load with mismatched genes. + tmp_adata = synthetic_iid( + n_genes=200, + ) + tmp_protein_adata = synthetic_iid(n_genes=50) + tmp_mdata = MuData({"rna": tmp_adata, "protein": tmp_protein_adata}) + with pytest.raises(ValueError): + MULTIVI.load(save_path, adata=tmp_mdata) + + # Load with different batches. + tmp_adata = synthetic_iid() + tmp_adata.obs["batch"] = tmp_adata.obs["batch"].cat.rename_categories(["batch_2", "batch_3"]) + tmp_protein_adata = synthetic_iid(n_genes=50) + tmp_mdata = MuData({"rna": tmp_adata, "protein": tmp_protein_adata}) + with pytest.raises(ValueError): + MULTIVI.load(save_path, adata=tmp_mdata) + + model = MULTIVI.load(save_path, adata=mdata) + assert REGISTRY_KEYS.BATCH_KEY in model.adata_manager.data_registry + assert model.adata_manager.data_registry.batch == attrdict( + {"mod_key": "rna", "attr_name": "obs", "attr_key": "_scvi_batch"} + ) + + z2 = model.get_latent_representation() + test_idx2 = model.validation_indices + np.testing.assert_array_equal(z1, z2) + np.testing.assert_array_equal(test_idx1, test_idx2) + assert model.is_trained is True + + save_path = os.path.join(save_path, "tmp") + + adata2 = synthetic_iid() + protein_adata2 = synthetic_iid(n_genes=50) + mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) + MULTIVI.setup_mudata( + mdata2, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + + +def test_scarches_mudata_prep_layer(save_path: str): + n_latent = 5 + mdata1 = synthetic_iid(return_mudata=True) + + mdata1["rna"].layers["counts"] = mdata1["rna"].X.copy() + MULTIVI.setup_mudata( + mdata1, + batch_key="batch", + modalities={"rna_layer": "rna", "protein_layer": "protein_expression"}, + ) + model = MULTIVI(mdata1, n_latent=n_latent) + model.train(1, check_val_every_n_epoch=1) + dir_path = os.path.join(save_path, "saved_model/") + model.save(dir_path, overwrite=True) + + # mdata2 has more genes and missing 10 genes from mdata1. + # protein/acessibility features are same as in mdata1 + mdata2 = synthetic_iid(n_genes=110, return_mudata=True) + mdata2["rna"].layers["counts"] = mdata2["rna"].X.copy() + new_var_names_init = [f"Random {i}" for i in range(10)] + new_var_names = new_var_names_init + mdata2["rna"].var_names[10:].to_list() + mdata2["rna"].var_names = new_var_names + + original_protein_values = mdata2["protein_expression"].X.copy() + original_accessibility_values = mdata2["accessibility"].X.copy() + + MULTIVI.prepare_query_mudata(mdata2, dir_path) + # should be padded 0s + assert np.sum(mdata2["rna"][:, mdata2["rna"].var_names[:10]].layers["counts"]) == 0 + np.testing.assert_equal( + mdata2["rna"].var_names[:10].to_numpy(), mdata1["rna"].var_names[:10].to_numpy() + ) + + # values of other modalities should be unchanged + np.testing.assert_equal(original_protein_values, mdata2["protein_expression"].X) + np.testing.assert_equal(original_accessibility_values, mdata2["accessibility"].X) + + # and names should also be the same + np.testing.assert_equal( + mdata2["protein_expression"].var_names.to_numpy(), + mdata1["protein_expression"].var_names.to_numpy(), + ) + np.testing.assert_equal( + mdata2["accessibility"].var_names.to_numpy(), mdata1["accessibility"].var_names.to_numpy() + ) + MULTIVI.load_query_data(mdata2, dir_path) + + +def test_multivi_save_load_mudata_format(save_path: str): + mdata = synthetic_iid(return_mudata=True, protein_expression_key="protein") + invalid_mdata = mdata.copy() + invalid_mdata.mod["protein"] = invalid_mdata.mod["protein"][:, :10].copy() + MULTIVI.setup_mudata( + mdata, + modalities={"rna_layer": "rna", "protein_layer": "protein"}, + ) + model = MULTIVI(mdata) + model.train(max_epochs=1) + + legacy_model_path = os.path.join(save_path, "legacy_model") + model.save( + legacy_model_path, + overwrite=True, + save_anndata=False, + legacy_mudata_format=True, + ) + + with pytest.raises(ValueError): + _ = MULTIVI.load(legacy_model_path, adata=invalid_mdata) + model = MULTIVI.load(legacy_model_path, adata=mdata) + + model_path = os.path.join(save_path, "model") + model.save( + model_path, + overwrite=True, + save_anndata=False, + legacy_mudata_format=False, + ) + with pytest.raises(ValueError): + _ = MULTIVI.load(legacy_model_path, adata=invalid_mdata) + model = MULTIVI.load(model_path, adata=mdata)