diff --git a/sfaira/data/base/dataset.py b/sfaira/data/base/dataset.py index 79cdba1bf..cc335a10b 100644 --- a/sfaira/data/base/dataset.py +++ b/sfaira/data/base/dataset.py @@ -674,7 +674,8 @@ def streamline_metadata( :param schema: Export format. - "sfaira" - "cellxgene" - :param uns_to_obs: Whether to move metadata in .uns to .obs to make sure it's not lost when concatenating multiple objects. + :param uns_to_obs: Whether to move metadata in .uns to .obs to make sure it's not lost when concatenating + multiple objects. Retains .id in .uns. :param clean_obs: Whether to delete non-streamlined fields in .obs, .obsm and .obsp. :param clean_var: Whether to delete non-streamlined fields in .var, .varm and .varp. :param clean_uns: Whether to delete non-streamlined fields in .uns. @@ -909,7 +910,9 @@ def streamline_metadata( for k, v in self.adata.uns.items(): if k not in self.adata.obs_keys(): self.adata.obs[k] = [v for i in range(self.adata.n_obs)] - self.adata.uns = {} + # Retain only target uns keys in .uns. + self.adata.uns = dict([(k, v) for k, v in self.adata.uns.items() + if k in [getattr(adata_target_ids, kk) for kk in ["id"]]]) self._adata_ids = adata_target_ids # set new adata fields to class after conversion self.streamlined_meta = True @@ -948,6 +951,7 @@ def write_distributed_store( f"data, found {type(self.adata.X)}") fn = os.path.join(dir_cache, self.doi_cleaned_id + ".h5ad") as_dense = ("X",) if dense else () + print(f"writing {self.adata.shape} into {fn}") self.adata.write_h5ad(filename=fn, as_dense=as_dense, **compression_kwargs) elif store == "zarr": fn = os.path.join(dir_cache, self.doi_cleaned_id) diff --git a/sfaira/data/base/distributed_store.py b/sfaira/data/base/distributed_store.py index 19ca1a14c..7d6796e63 100644 --- a/sfaira/data/base/distributed_store.py +++ b/sfaira/data/base/distributed_store.py @@ -8,6 +8,7 @@ from sfaira.consts import AdataIdsSfaira, OCS from sfaira.data.base.dataset import is_child, UNS_STRING_META_IN_OBS +from sfaira.versions.genomes import GenomeContainer from sfaira.versions.metadata import CelltypeUniverse @@ -20,7 +21,7 @@ class DistributedStore: indices: Dict[str, np.ndarray] - def __init__(self, cache_path: Union[str, None] = None): + def __init__(self, cache_path: Union[str, os.PathLike, None] = None): """ This class is instantiated on a cache directory which contains pre-processed files in rapid access format. @@ -30,6 +31,7 @@ def __init__(self, cache_path: Union[str, None] = None): - zarr :param cache_path: Directory in which pre-processed .h5ad files lie. + :param genome_container: GenomeContainer with target features space defined. 
""" # Collect all data loaders from files in directory: adatas = {} @@ -53,72 +55,122 @@ def __init__(self, cache_path: Union[str, None] = None): self.adatas = adatas self.indices = indices self.ontology_container = OCS + self._genome_container = None self._adata_ids_sfaira = AdataIdsSfaira() self._celltype_universe = None @property def adata(self): - return list(self.adatas.values)[0].concatenate( - list(self.adatas.values)[1:] + return self.adatas[list(self.adatas.keys())[0]].concatenate( + *[self.adatas[k] for k in list(self.adatas.keys())[1:]], + batch_key="dataset_id", + batch_categories=list(self.adatas.keys()), ) + @property + def genome_container(self) -> Union[GenomeContainer, None]: + return self._genome_container + + @genome_container.setter + def genome_container(self, x: GenomeContainer): + var_names = self.__validate_feature_space_homogeneity() + # Validate genome container choice: + # Make sure that all var names defined in genome container are also contained in loaded data sets. + assert np.all([y in var_names for y in x.ensembl]), \ + "did not find variable names from genome container in store" + self._genome_container = x + + def __validate_feature_space_homogeneity(self) -> List[str]: + """ + Assert that the data sets which were kept have the same feature names. + """ + var_names = self.adatas[list(self.adatas.keys())[0]].var_names.tolist() + for k, v in self.adatas.items(): + assert len(var_names) == len(v.var_names), f"number of features in store differed in object {k}" + assert np.all(var_names == v.var_names), f"var_names in store were not matched in object {k}" + return var_names + def generator( self, + idx: Union[np.ndarray, None] = None, batch_size: int = 1, obs_keys: List[str] = [], - continuous_batches: bool = True, + return_dense: bool = True, ) -> iter: """ Yields an unbiased generator over observations in the contained data sets. - :param batch_size: Number of observations in each batch (generator invocation). + :param idx: Global idx to query from store. These is an array with indicies corresponding to a contiuous index + along all observations in self.adatas, ordered along a hypothetical concatenation along the keys of + self.adatas. + :param batch_size: Number of observations in each batch (generator invocation). Increasing this may result in + large speed-ups in query time but removes the ability of upstream generators to fully shuffle cells, as + these batches are the smallest data unit that upstream generators can access. :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available in self.adatas. - :param continuous_batches: Whether to build batches of batch_size across data set boundaries if end of one - data set is reached. + :param return_dense: Whether to return count data .X as dense batches. This allows more efficient feature + indexing if the store is sparse (column indexing on csr matrices is slow). :return: Generator function which yields batch_size at every invocation. The generator returns a tuple of (.X, .obs) with types: - - if store format is h5ad: (scipy.sparse.csr_matrix, pandas.DataFrame) + - if store format is h5ad: (Union[scipy.sparse.csr_matrix, np.ndarray], pandas.DataFrame) """ + # Make sure that features are ordered in the same way in each object so that generator yields consistent cell + # vectors. 
+ _ = self.__validate_feature_space_homogeneity() + var_names_store = self.adatas[list(self.adatas.keys())[0]].var_names.tolist() + # Use feature space sub-selection based on assembly if provided, will use full feature space otherwise. + if self.genome_container is not None: + var_names_target = self.genome_container.ensembl + var_idx = np.sort([var_names_store.index(x) for x in var_names_target]) + # Check if the index vector is just the full ordered list of indices; in this case, sub-setting is unnecessary. + if len(var_idx) == len(var_names_store) and np.all(var_idx == np.arange(0, len(var_names_store))): + var_idx = None + else: + var_idx = None def generator() -> tuple: - n_datasets = len(list(self.adatas.keys())) - x_last = None - obs_last = None + global_index_set = dict(list(zip(list(self.adatas.keys()), self.indices_global))) for i, (k, v) in enumerate(self.adatas.items()): # Define batch partitions: - if continuous_batches and x_last is not None: - # Prepend data set with residual data from last data set. - remainder_start = x_last.shape[0] - n_obs = v.n_obs + remainder_start + # Get subset of target indices that fall into this data set. + # Use indices relative to this data set (via .index here). + # continuous_slices is evaluated to establish whether slicing can be performed as the potentially + # faster [start:end] or needs to be index-wise [indices]. + if idx is not None: + idx_i = [global_index_set[k].tolist().index(x) for x in idx if x in global_index_set[k]] + idx_i = np.sort(idx_i) + continuous_slices = np.all(idx_i == np.arange(0, v.n_obs)) else: - # Partition into equally sized batches up to last batch. - remainder_start = 0 - n_obs = v.n_obs - remainder = n_obs % batch_size - batch_starts = [ - np.min([0, int(x * batch_size - remainder_start)]) - for x in np.arange(1, n_obs // batch_size + int(remainder > 0)) - ] - n_batches = len(batch_starts) - # Iterate over batches: - for j, x in enumerate(batch_starts): - batch_end = int(x + batch_size) - x = v.X[x:batch_end, :] - obs = v.obs[obs_keys].iloc[x:batch_end, :] - assert isinstance(x, scipy.sparse.csr_matrix), f"{type(x)}" - assert isinstance(obs, pd.DataFrame), f"{type(obs)}" - if continuous_batches and remainder > 0 and i < (n_datasets - 1) and j == (n_batches - 1): - # Cache incomplete last batch to append to next first batch of next data set. - x_last = x - obs_last = obs - elif continuous_batches and x_last is not None: - # Append last incomplete batch current batch. - x = scipy.sparse.hstack(blocks=[x_last, x], format="csr") - obs = pd.concat(objs=[obs_last, obs], axis=0) - yield x, obs - else: + idx_i = np.arange(0, v.n_obs) + continuous_slices = True + if len(idx_i) > 0: # Skip data objects without matched cells. + n_obs = len(idx_i) + # Cells left over after batching to batch size, accounting for overhang: + remainder = n_obs % batch_size + batch_starts_ends = [ + (int(x * batch_size), int(x * batch_size) + batch_size) + for x in np.arange(0, n_obs // batch_size + int(remainder > 0)) + ] + # Iterate over batches: + for j, (s, e) in enumerate(batch_starts_ends): + if continuous_slices: + e = idx_i[e] if e < n_obs else n_obs + x = v.X[idx_i[s]:e, :] + else: + x = v.X[idx_i[s:e], :] + # Do dense conversion now so that col-wise indexing is not slow; often, dense conversion + # would be done later anyway.
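For reference, the partition arithmetic above yields start/end pairs that cover all observations of one data set, with a shorter final batch when the count is not a multiple of batch_size (numbers are illustrative):

import numpy as np

n_obs, batch_size = 10, 4
remainder = n_obs % batch_size  # 2
batch_starts_ends = [
    (int(x * batch_size), int(x * batch_size) + batch_size)
    for x in np.arange(0, n_obs // batch_size + int(remainder > 0))
]
# [(0, 4), (4, 8), (8, 12)]; the final end index over-shoots and is clipped when slicing.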
+ if return_dense: + x = x.todense() + if var_idx is not None: + x = x[:, var_idx] + if continuous_slices: + e = idx_i[e] if e < n_obs else n_obs + obs = v.obs[obs_keys].iloc[idx_i[s]:e, :] + else: + obs = v.obs[obs_keys].iloc[idx_i[s:e], :] + assert isinstance(obs, pd.DataFrame), f"{type(obs)}" # Yield current batch. yield x, obs @@ -134,45 +186,7 @@ def celltypes_universe(self) -> CelltypeUniverse: ) return self._celltype_universe - def subset(self, attr_key, values): - """ - Subset list of adata objects based on match to values in key property. - - Keys need to be available in adata.uns - - :param attr_key: Property to subset by. - :param values: Classes to overlap to. - :return: - """ - if isinstance(values, np.ndarray): - values = values.tolist() - if isinstance(values, tuple): - values = list(values) - if not isinstance(values, list): - values = [values] - # Get ontology container to be able to do relational reasoning: - ontology = getattr(self.ontology_container, attr_key) - for k in list(self.adatas.keys()): - if getattr(self._adata_ids_sfaira, attr_key) in self.adatas[k].uns.keys(): - if getattr(self._adata_ids_sfaira, attr_key) != UNS_STRING_META_IN_OBS: - values_found = self.adatas[k].uns[getattr(self._adata_ids_sfaira, attr_key)] - else: - values_found = self.adatas[k].obs[getattr(self._adata_ids_sfaira, attr_key)].values.tolist() - if not isinstance(values_found, list): - values_found = [values_found] - if not np.any([ - np.any([ - is_child(query=x, ontology=ontology, ontology_parent=y) - for y in values - ]) for x in values_found - ]): - # Delete entries which a non-matching meta data value associated with this item. - del self.adatas[k] - else: - # Delete entries which did not have this key annotated. - del self.adatas[k] - - def subset_cells_idx(self, attr_key, values: Union[str, List[str]]): + def _get_subset_idx(self, attr_key, values: Union[str, List[str]]): """ Get indices of subset list of adata objects based on cell-wise properties. @@ -197,25 +211,27 @@ def subset_cells_idx(self, attr_key, values: Union[str, List[str]]): values = [values] def get_subset_idx(adata, k, dataset): - # Try to look first in cell wise annotation to use cell-wise map if data set-wide maps are ambiguous: + # Use cell-wise annotation if data set-wide maps are ambiguous: # This can happen if the different cell-wise annotations are summarised as a union in .uns. - if getattr(self._adata_ids_sfaira, k) in adata.obs.keys(): - values_found = adata.obs[getattr(self._adata_ids_sfaira, k)].values - elif getattr(self._adata_ids_sfaira, k) in adata.uns.keys(): + if getattr(self._adata_ids_sfaira, k) in adata.uns.keys() and \ + adata.uns[getattr(self._adata_ids_sfaira, k)] != UNS_STRING_META_IN_OBS: values_found = adata.uns[getattr(self._adata_ids_sfaira, k)] if isinstance(values_found, np.ndarray): values_found = values_found.tolist() elif not isinstance(values_found, list): values_found = [values_found] if len(values_found) > 1: - print(f"WARNING: subsetting not exact for attribute {k}: {values_found}," - f" discarding data set {dataset}.") - values_found = [] + values_found = None # Go to cell-wise annotation. else: # Replicate unique property along cell dimension. 
values_found = [values_found[0] for i in range(adata.n_obs)] else: - raise ValueError(f"did not find attribute {k} in data set {dataset}") + values_found = None + if values_found is None: + if getattr(self._adata_ids_sfaira, k) in adata.obs.keys(): + values_found = adata.obs[getattr(self._adata_ids_sfaira, k)].values + else: + raise ValueError(f"did not find unique attribute {k} in data set {dataset}") values_found_unique = np.unique(values_found) try: ontology = getattr(self.ontology_container, k) @@ -238,10 +254,10 @@ def get_subset_idx(adata, k, dataset): idx_old = self.indices[k].tolist() idx_new = get_subset_idx(adata=v, k=attr_key, dataset=k) # Keep intersection of old and new hits. - indices[k] = np.array(list(set(idx_old).intersection(set(idx_new)))) + indices[k] = np.asarray(list(set(idx_old).intersection(set(idx_new))), dtype="int32") return indices - def subset_cells(self, attr_key, values: Union[str, List[str]]): + def subset(self, attr_key, values: Union[str, List[str]]): """ Subset list of adata objects based on cell-wise properties. @@ -263,13 +279,13 @@ def subset_cells(self, attr_key, values: Union[str, List[str]]): - "state_exact" points to self.state_exact_obs_key :param values: Classes to overlap to. """ - self.indices = self.subset_cells_idx(attr_key=attr_key, values=values) + self.indices = self._get_subset_idx(attr_key=attr_key, values=values) for k, v in self.indices.items(): if v.shape[0] == 0: # No observations (cells) left. del self.adatas[k] - def subset_cells_idx_global(self, attr_key, values: Union[str, List[str]]): + def subset_cells_idx_global(self, attr_key, values: Union[str, List[str]]) -> np.ndarray: """ Get indices of subset list of adata objects based on cell-wise properties treating instance as single array. @@ -293,15 +309,27 @@ def subset_cells_idx_global(self, attr_key, values: Union[str, List[str]]): :return Index vector """ # Get indices of of cells in target set by file. - idx_by_dataset = self.subset_cells_idx(attr_key=attr_key, values=values) + idx_by_dataset = self._get_subset_idx(attr_key=attr_key, values=values) # Translate file-wise indices into global index list across all data sets. idx = [] counter = 0 + for k, v in idx_by_dataset.items(): + idx.extend((v + counter).tolist()) + counter += self.adatas[k].n_obs + return np.asarray(idx) + + @property + def indices_global(self): + """ + Increasing indices across data sets which can be concatenated into a single index vector with unique entries + for cells. + """ + counter = 0 + indices = [] for k, v in self.adatas.items(): - idx_k = np.arange(counter, counter + v.n_obs) - idx.extend(idx_k[idx_by_dataset[k]]) + indices.append(np.arange(counter, counter + v.n_obs)) counter += v.n_obs - return idx + return indices def write_config(self, fn: Union[str, os.PathLike]): """ @@ -331,11 +359,37 @@ def load_config(self, fn: Union[str, os.PathLike]): # Only retain data sets with which are mentioned in config file. self.subset(attr_key="id", values=list(self.indices.keys())) + @property + def var_names(self): + var_names = self.__validate_feature_space_homogeneity() + # Use feature space sub-selection based on assembly if provided, will use full feature space otherwise. 
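The renamed subsetting and config API combines as follows (a sketch along the lines of the unit tests further below; paths are placeholders):

import os
import sys
from sfaira.data import DistributedStore

store_path = sys.argv[1]  # placeholder path to a pre-processed store directory
store = DistributedStore(cache_path=store_path)
store.subset(attr_key="assay_sc", values=["10x sequencing"])   # cell-wise subset, drops emptied data sets
store.write_config(fn=os.path.join(store_path, "config_10x"))  # persists self.indices
store2 = DistributedStore(cache_path=store_path)
store2.load_config(fn=os.path.join(store_path, "config_10x"))  # recovers the same indices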
+ if self.genome_container is None: + return var_names + else: + return self.genome_container.ensembl + @property def n_vars(self): - # assumes that all adata - return list(self.adatas.values())[0].n_vars + var_names = self.__validate_feature_space_homogeneity() + # Use feature space sub-selection based on assembly if provided, will use full feature space otherwise. + if self.genome_container is None: + return len(var_names) + else: + return self.genome_container.n_var @property def n_obs(self): - return np.sum([len(v) for _, v in self.indices]) + return np.sum([len(v) for v in self.indices.values()]) + + @property + def shape(self): + return [self.n_obs, self.n_vars] + + @property + def obs(self) -> pd.DataFrame: + """ + Assemble .obs table of subset of full data. + + :return: .obs data frame. + """ + return pd.concat([v.obs for v in self.adatas.values()], axis=0) diff --git a/sfaira/estimators/keras.py b/sfaira/estimators/keras.py index 2ccac6b21..0f504bc3c 100644 --- a/sfaira/estimators/keras.py +++ b/sfaira/estimators/keras.py @@ -2,6 +2,7 @@ import anndata import hashlib import numpy as np +import pandas as pd import scipy.sparse try: import tensorflow as tf @@ -23,13 +24,16 @@ def prepare_sf(x): + """ + Uses a minimal size factor of 1e-3 for total counts / 1e4 + """ if len(x.shape) == 2: sf = np.asarray(x.sum(axis=1)).flatten() elif len(x.shape) == 1: sf = np.asarray(x.sum()).flatten() else: raise ValueError("x.shape > 2") - sf = np.log(sf / 1e4 + 1e-10) + sf = np.log(np.maximum(sf / 1e4, 1e-3)) return sf @@ -65,6 +69,9 @@ def __init__( self.model_id = model_id self.model_class = model_class self.topology_container = model_topology + # Prepare store with genome container sub-setting: + if isinstance(self.data, DistributedStore): + self.data.genome_container = self.topology_container.gc self.history = None self.train_hyperparam = None @@ -176,8 +183,9 @@ def _get_dataset( batch_size: Union[int, None], mode: str, shuffle_buffer_size: int, - prefetch: int, - weighted: bool + cache_full: bool, + weighted: bool, + retrieval_batch_size: int, ): pass @@ -195,7 +203,14 @@ def _get_class_dict( label_dict.update({label: float(i)}) return label_dict - def _prepare_data_matrix(self, idx: Union[np.ndarray, None]): + def _prepare_data_matrix(self, idx: Union[np.ndarray, None]) -> scipy.sparse.csr_matrix: + """ + Subsets observations x features matrix in .data to observation indices (idx, the split) and features defined + by topology. + + :param idx: Observation index split. + :return: Data matrix + """ # Check that AnnData is not backed. If backed, assume that these processing steps were done before. if self.data.isbacked: raise ValueError("tried running backed AnnData object through standard pipeline") @@ -224,35 +239,19 @@ def _prepare_data_matrix(self, idx: Union[np.ndarray, None]): return x # Compute indices of genes to keep - data_ids = self.data.var[self._adata_ids.gene_id_ensembl].values - idx_feature_kept = np.where([x in self.topology_container.gc.ensembl for x in data_ids])[0] - idx_feature_map = np.array([self.topology_container.gc.ensembl.index(x) - for x in data_ids[idx_feature_kept]]) - - # Convert to csc and remove unmapped genes - x = x.tocsc() - x = x[:, idx_feature_kept] - - # Create reordered feature matrix based on reference and convert to csr - x_new = scipy.sparse.csc_matrix((x.shape[0], self.topology_container.n_var), dtype=x.dtype) - # copying this over to the new matrix in chunks of size `steps` prevents a strange scipy error: - # ... 
scipy/sparse/compressed.py", line 922, in _zero_many i, j, offsets) - # ValueError: could not convert integer scalar - step = 500 - if step < len(idx_feature_map): - for i in range(0, len(idx_feature_map), step): - x_new[:, idx_feature_map[i:i + step]] = x[:, i:i + step] - x_new[:, idx_feature_map[i + step:]] = x[:, i + step:] - else: - x_new[:, idx_feature_map] = x - - x_new = x_new.tocsr() - - print(f"found {len(idx_feature_kept)} intersecting features between {x.shape[1]} " - f"features in input data set and {self.topology_container.n_var} features in reference genome") - print(f"found {x_new.shape[0]} observations") - - return x_new + data_ids = self.data.var[self._adata_ids.gene_id_ensembl].values.tolist() + target_ids = self.topology_container.gc.ensembl + idx_map = np.array([data_ids.index(z) for z in target_ids]) + # Assert that each ID from target IDs appears exactly once in data IDs: + assert np.all([z in data_ids for z in target_ids]), "not all target feature IDs found in data" + assert np.all([np.sum(z == np.array(data_ids)) <= 1. for z in target_ids]), \ + "duplicated target feature IDs exist in data" + # Map feature space. + x = x[:, idx_map] + print(f"found {len(idx_map)} intersecting features between {x.shape[1]} features in input data set and" + f" {self.topology_container.n_var} features in reference genome") + print(f"found {x.shape[0]} observations") + return x @abc.abstractmethod def _get_loss(self): @@ -283,6 +282,7 @@ def train( test_split: Union[float, dict] = 0., validation_batch_size: int = 256, max_validation_steps: Union[int, None] = 10, + cache_full: bool = False, patience: int = 20, lr_schedule_min_lr: float = 1e-5, lr_schedule_factor: float = 0.2, @@ -290,7 +290,7 @@ def train( shuffle_buffer_size: int = int(1e4), log_dir: Union[str, None] = None, callbacks: Union[list, None] = None, - weighted: bool = True, + weighted: bool = False, verbose: int = 2 ): """ @@ -350,7 +350,7 @@ def train( } # Set callbacks. 
- cbs = [] + cbs = [tf.keras.callbacks.TerminateOnNaN()] if patience is not None and patience > 0: cbs.append(tf.keras.callbacks.EarlyStopping( monitor='val_loss', @@ -391,24 +391,19 @@ def train( if isinstance(test_split, float) or isinstance(test_split, int): self.idx_test = np.random.choice( a=all_idx, - size=round(self.data.shape[0] * test_split), + size=round(self.data.n_obs * test_split), replace=False, ) elif isinstance(test_split, dict): - if isinstance(self.data, anndata.AnnData): - in_test = np.ones((self.data.obs.shape[0],), dtype=int) == 1 - for k, v in test_split.items(): - if isinstance(v, list): - in_test = np.logical_and(in_test, np.array([x in v for x in self.data.obs[k].values])) - else: - in_test = np.logical_and(in_test, self.data.obs[k].values == v) - self.idx_test = np.where(in_test)[0] - print(f"Found {len(self.idx_test)} out of {self.data.n_obs} cells that correspond to held out data set") - print(self.idx_test) - else: - assert len(test_split.values()) == 1 - self.idx_test = self.data.subset_cells_idx_global(attr_key=list(test_split.keys())[0], - values=list(test_split.values())[0]) + in_test = np.ones((self.data.n_obs,), dtype=int) == 1 + for k, v in test_split.items(): + if isinstance(v, list): + in_test = np.logical_and(in_test, np.array([x in v for x in self.data.obs[k].values])) + else: + in_test = np.logical_and(in_test, self.data.obs[k].values == v) + self.idx_test = np.where(in_test)[0] + print(f"Found {len(self.idx_test)} out of {self.data.n_obs} cells that correspond to held out data set") + print(self.idx_test) else: raise ValueError("type of test_split %s not recognized" % type(test_split)) idx_train_eval = np.array([x for x in all_idx if x not in self.idx_test]) @@ -434,14 +429,16 @@ def train( batch_size=batch_size, mode='train', shuffle_buffer_size=min(shuffle_buffer_size, len(self.idx_train)), - weighted=weighted + weighted=weighted, + cache_full=cache_full, ) eval_dataset = self._get_dataset( idx=self.idx_eval, batch_size=validation_batch_size, mode='train_val', shuffle_buffer_size=min(shuffle_buffer_size, len(self.idx_eval)), - weighted=weighted + weighted=weighted, + cache_full=cache_full, ) steps_per_epoch = min(max(len(self.idx_train) // batch_size, 1), max_steps_per_epoch) @@ -469,6 +466,18 @@ def get_citations(self): def using_store(self) -> bool: return isinstance(self.data, DistributedStore) + @property + def obs_train(self): + return self.data.obs.iloc[self.idx_train, :] + + @property + def obs_eval(self): + return self.data.obs.iloc[self.idx_eval, :] + + @property + def obs_test(self): + return self.data.obs.iloc[self.idx_test, :] + class EstimatorKerasEmbedding(EstimatorKeras): """ @@ -539,6 +548,7 @@ def _get_base_generator( self, generator_helper, idx: Union[np.ndarray, None], + batch_size: int = 1, ): """ Yield a basic generator based on which a tf dataset can be built. 
@@ -557,30 +567,40 @@ def _get_base_generator( # Prepare data reading according to whether anndata is backed or not: if self.using_store: generator_raw = self.data.generator( - batch_size=1, + idx=idx, + batch_size=batch_size, obs_keys=[], - continuous_batches=True, + return_dense=True, ) def generator(): - counter = -1 - for z in generator_raw: - counter += 1 - if counter in idx: - x_sample = z[0].toarray().flatten() - yield generator_helper(x_sample=x_sample) + for z in generator_raw(): + x_sample = z[0] + if isinstance(x_sample, scipy.sparse.csr_matrix): + x_sample = x_sample.todense() + x_sample = np.asarray(x_sample) + for i in range(x_sample.shape[0]): + yield generator_helper(x_sample=x_sample[i]) n_features = self.data.n_vars n_samples = self.data.n_obs else: x = self.data.X if self.data.isbacked else self._prepare_data_matrix(idx=idx) + indices = idx if self.data.isbacked else range(x.shape[0]) + n_obs = len(indices) + remainder = n_obs % batch_size + batch_starts_ends = [ + (int(x * batch_size), int(x * batch_size) + batch_size) + for x in np.arange(0, n_obs // batch_size + int(remainder > 0)) + ] def generator(): is_sparse = isinstance(x[0, :], scipy.sparse.spmatrix) - indices = idx if self.data.isbacked else range(x.shape[0]) - for i in indices: - x_sample = x[i, :].toarray().flatten() if is_sparse else x[i, :].flatten() - yield generator_helper(x_sample=x_sample) + for s, e in batch_starts_ends: + x_sample = np.asarray(x[indices[s:e], :].todense()) if is_sparse \ + else x[indices[s:e], :] + for i in range(x_sample.shape[0]): + yield generator_helper(x_sample=x_sample[i]) n_features = x.shape[1] n_samples = x.shape[0] @@ -593,8 +613,9 @@ def _get_dataset( batch_size: Union[int, None], mode: str, shuffle_buffer_size: int = int(1e7), - prefetch: int = 10, + cache_full: bool = False, weighted: bool = False, + retrieval_batch_size: int = 128, ): """ @@ -621,6 +642,7 @@ def generator_helper(x_sample): generator, n_samples, n_features = self._get_base_generator( generator_helper=generator_helper, idx=idx, + batch_size=retrieval_batch_size, ) output_types, output_shapes = self._get_output_dim(n_features=n_features, model_type=model_type, mode=mode) dataset = tf.data.Dataset.from_generator( @@ -628,6 +650,8 @@ def generator_helper(x_sample): output_types=output_types, output_shapes=output_shapes ) + if cache_full: + dataset = dataset.cache() # Only shuffle in train modes if mode in ['train', 'train_val']: dataset = dataset.repeat() @@ -635,30 +659,31 @@ def generator_helper(x_sample): buffer_size=min(n_samples, shuffle_buffer_size), seed=None, reshuffle_each_iteration=True) - dataset = dataset.batch(batch_size).prefetch(prefetch) + dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE) return dataset - elif mode == 'gradient_method': + elif mode == 'gradient_method': # TODO depreceate this code # Prepare data reading according to whether anndata is backed or not: cell_to_class = self._get_class_dict(obs_key=self._adata_ids.cell_ontology_class) if self.using_store: n_features = self.data.n_vars generator_raw = self.data.generator( + idx=idx, batch_size=1, obs_keys=["cell_ontology_class"], - continuous_batches=True, + return_dense=True, ) def generator(): - counter = -1 - for z in generator_raw: - counter += 1 - if counter in idx: - x_sample = z[0].toarray().flatten() - sf_sample = prepare_sf(x=x_sample)[0] - y_sample = z[1]["cell_ontology_class"].values[0] - yield (x_sample, sf_sample), (x_sample, cell_to_class[y_sample]) + for z in generator_raw(): + x_sample = z[0] + if 
isinstance(x_sample, scipy.sparse.csr_matrix): + x_sample = x_sample.todense() + x_sample = np.asarray(x_sample).flatten() + sf_sample = prepare_sf(x=x_sample)[0] + y_sample = z[1]["cell_ontology_class"].values[0] + yield (x_sample, sf_sample), (x_sample, cell_to_class[y_sample]) elif isinstance(self.data, anndata.AnnData) and self.data.isbacked: n_features = self.data.X.shape[1] @@ -691,7 +716,7 @@ def generator(): buffer_size=shuffle_buffer_size, seed=None, reshuffle_each_iteration=True - ).batch(batch_size).prefetch(prefetch) + ).batch(batch_size).prefetch(tf.data.AUTOTUNE) return dataset @@ -752,7 +777,8 @@ def evaluate_any(self, idx, batch_size: int = 1, max_steps: int = np.inf): dataset = self._get_dataset( idx=idx, batch_size=batch_size, - mode='eval' + mode='eval', + retrieval_batch_size=128, ) steps = min(max(len(idx) // batch_size, 1), max_steps) results = self.model.training_model.evaluate(x=dataset, steps=steps) @@ -760,7 +786,7 @@ def evaluate_any(self, idx, batch_size: int = 1, max_steps: int = np.inf): else: return {} - def evaluate(self, batch_size: int = 1, max_steps: int = np.inf): + def evaluate(self, batch_size: int = 64, max_steps: int = np.inf): """ Evaluate the custom model on test data. @@ -772,7 +798,7 @@ def evaluate(self, batch_size: int = 1, max_steps: int = np.inf): """ return self.evaluate_any(idx=self.idx_test, batch_size=batch_size, max_steps=max_steps) - def predict(self): + def predict(self, batch_size: int = 64, max_steps: int = np.inf): """ return the prediction of the model @@ -782,14 +808,15 @@ def predict(self): if self.idx_test is None or self.idx_test.any(): # true if the array is not empty or if the passed value is None dataset = self._get_dataset( idx=self.idx_test, - batch_size=64, - mode='predict' + batch_size=batch_size, + mode='predict', + retrieval_batch_size=128, ) return self.model.predict_reconstructed(x=dataset) else: return np.array([]) - def predict_embedding(self): + def predict_embedding(self, batch_size: int = 64, max_steps: int = np.inf): """ return the prediction in the latent space (z_mean for variational models) @@ -799,14 +826,15 @@ def predict_embedding(self): if self.idx_test is None or self.idx_test.any(): # true if the array is not empty or if the passed value is None dataset = self._get_dataset( idx=self.idx_test, - batch_size=64, - mode='predict' + batch_size=batch_size, + mode='predict', + retrieval_batch_size=128, ) return self.model.predict_embedding(x=dataset, variational=False) else: return np.array([]) - def predict_embedding_variational(self): + def predict_embedding_variational(self, batch_size: int = 64, max_steps: int = np.inf): """ return the prediction of z, z_mean, z_log_var in the variational latent space @@ -816,8 +844,9 @@ def predict_embedding_variational(self): if self.idx_test is None or self.idx_test: # true if the array is not empty or if the passed value is None dataset = self._get_dataset( idx=self.idx_test, - batch_size=64, - mode='predict' + batch_size=batch_size, + mode='predict', + retrieval_batch_size=128, ) return self.model.predict_embedding(x=dataset, variational=True) else: @@ -979,15 +1008,21 @@ def ontology_ids(self): def _one_hot_encoder(self): - def encoder(x): - idx = self.celltype_universe.onto_cl.map_to_leaves( - node=x, - return_type="idx", - include_self=True, - ) - y = np.zeros((self.ntypes,), dtype="float32") - y[idx] = 1. 
/ len(idx) - return y + def encoder(x) -> np.ndarray: + if isinstance(x, str): + x = [x] + idx = [ + self.celltype_universe.onto_cl.map_to_leaves( + node=y, + return_type="idx", + include_self=True, + ) + for y in x + ] + oh = np.zeros((len(x), self.ntypes,), dtype="float32") + for i, y in enumerate(idx): + oh[i, y] = 1. / len(y) + return oh return encoder @@ -1006,7 +1041,7 @@ def _get_celltype_out( # One whether "unknown" is already included, otherwise add one extra column. onehot_encoder = self._one_hot_encoder() y = np.concatenate([ - np.expand_dims(onehot_encoder(z), axis=0) + onehot_encoder(z) for z in self.data.obs[self._adata_ids.cell_ontology_class].values[idx].tolist() ], axis=0) # Distribute aggregated class weight for computation of weights: @@ -1039,6 +1074,7 @@ def _get_base_generator( generator_helper, idx: Union[np.ndarray, None], weighted: bool = False, + batch_size: int = 1, ): """ Yield a basic generator based on which a tf dataset can be built. @@ -1061,20 +1097,22 @@ def _get_base_generator( if weighted: raise ValueError("using weights with store is not supported yet") generator_raw = self.data.generator( - batch_size=1, + idx=idx, + batch_size=batch_size, obs_keys=["cell_ontology_class"], - continuous_batches=True, + return_dense=True, ) onehot_encoder = self._one_hot_encoder() def generator(): - counter = -1 - for z in generator_raw: - counter += 1 - if counter in idx: - x_sample = z[0].toarray().flatten() - y_sample = onehot_encoder(z[0]["cell_ontology_class"].values[0]) - yield generator_helper(x_sample, y_sample, 1.) + for z in generator_raw(): + x_sample = z[0] + if isinstance(x_sample, scipy.sparse.csr_matrix): + x_sample = x_sample.todense() + x_sample = np.asarray(x_sample) + y_sample = onehot_encoder(z[1]["cell_ontology_class"].values) + for i in range(x_sample.shape[0]): + yield generator_helper(x_sample[i], y_sample[i], 1.) 
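The reworked encoder one-hot encodes a batch of labels at once and spreads probability mass uniformly over matched leaf nodes; a standalone sketch with a made-up label-to-leaf map standing in for onto_cl.map_to_leaves:

import numpy as np

ntypes = 4
leaf_map = {"T cell": [0], "lymphocyte": [0, 1]}  # hypothetical label -> leaf-index map

def encoder(labels) -> np.ndarray:
    if isinstance(labels, str):
        labels = [labels]
    oh = np.zeros((len(labels), ntypes), dtype="float32")
    for i, lab in enumerate(labels):
        idx = leaf_map[lab]
        oh[i, idx] = 1. / len(idx)  # fractional assignment over descendant leaves
    return oh

encoder(["T cell", "lymphocyte"])
# array([[1. , 0. , 0. , 0. ],
#        [0.5, 0.5, 0. , 0. ]], dtype=float32)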
n_features = self.data.n_vars n_samples = self.data.n_obs @@ -1084,15 +1122,23 @@ def generator(): if not weighted: weights = np.ones_like(weights) x = self.data.X if self.data.isbacked else self._prepare_data_matrix(idx=idx) + is_sparse = isinstance(x, scipy.sparse.spmatrix) + indices = idx if self.data.isbacked else range(x.shape[0]) + n_obs = len(indices) + remainder = n_obs % batch_size + batch_starts_ends = [ + (int(x * batch_size), int(x * batch_size) + batch_size) + for x in np.arange(0, n_obs // batch_size + int(remainder > 0)) + ] def generator(): - is_sparse = isinstance(x[0, :], scipy.sparse.spmatrix) - indices = idx if self.data.isbacked else range(x.shape[0]) - for i in indices: - x_sample = np.asarray(x[i, :].todense()).flatten() if is_sparse else x[i, :].flatten() - y_sample = y[i, :] - w_sample = weights[i] - yield generator_helper(x_sample, y_sample, w_sample) + for s, e in batch_starts_ends: + x_sample = np.asarray(x[indices[s:e], :].todense()) if is_sparse \ + else x[indices[s:e], :] + y_sample = y[indices[s:e], :] + w_sample = weights[indices[s:e]] + for i in range(x_sample.shape[0]): + yield generator_helper(x_sample[i], y_sample[i], w_sample[i]) n_features = x.shape[1] n_samples = x.shape[0] @@ -1106,8 +1152,9 @@ def _get_dataset( batch_size: Union[int, None], mode: str, shuffle_buffer_size: int = int(1e7), - prefetch: int = 10, - weighted: bool = True, + cache_full: bool = False, + weighted: bool = False, + retrieval_batch_size: int = 128, ): """ @@ -1129,6 +1176,7 @@ def generator_helper(x_sample, y_sample, w_sample): generator_helper=generator_helper, idx=idx, weighted=weighted, + batch_size=retrieval_batch_size, ) output_types, output_shapes = self._get_output_dim(n_features=n_features, n_labels=n_labels, mode=mode) dataset = tf.data.Dataset.from_generator( @@ -1136,6 +1184,8 @@ def generator_helper(x_sample, y_sample, w_sample): output_types=output_types, output_shapes=output_shapes ) + if cache_full: + dataset = dataset.cache() if mode == 'train' or mode == 'train_val': dataset = dataset.repeat() dataset = dataset.shuffle( @@ -1143,7 +1193,7 @@ def generator_helper(x_sample, y_sample, w_sample): seed=None, reshuffle_each_iteration=True ) - dataset = dataset.batch(batch_size).prefetch(prefetch) + dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE) return dataset @@ -1160,11 +1210,7 @@ def _metrics(self): CustomTprClasswise(k=self.ntypes) ] - def predict( - self, - batch_size: int = 1, - max_steps: int = np.inf, - ): + def predict(self, batch_size: int = 64, max_steps: int = np.inf): """ Return the prediction of the model @@ -1177,14 +1223,15 @@ def predict( dataset = self._get_dataset( idx=idx, batch_size=batch_size, - mode='predict' + mode='predict', + retrieval_batch_size=128, ) steps = min(max(len(idx) // batch_size, 1), max_steps) return self.model.training_model.predict(x=dataset, steps=steps) else: return np.array([]) - def ytrue(self): + def ytrue(self, batch_size: int = 64, max_steps: int = np.inf): """ Return the true labels of the test set. @@ -1193,7 +1240,7 @@ def ytrue(self): if self.idx_test is None or self.idx_test.any(): # true if the array is not empty or if the passed value is None x, y, w = self._get_dataset( idx=self.idx_test, - batch_size=None, + batch_size=batch_size, mode='eval' ) return y @@ -1205,7 +1252,7 @@ def evaluate_any( idx, batch_size: int = 1, max_steps: int = np.inf, - weighted: bool = True + weighted: bool = False ): """ Evaluate the custom model on any local data. 
@@ -1224,7 +1271,8 @@ def evaluate_any( idx=idx, batch_size=batch_size, mode='eval', - weighted=weighted + weighted=weighted, + retrieval_batch_size=128, ) steps = min(max(len(idx) // batch_size, 1), max_steps) results = self.model.training_model.evaluate(x=dataset, steps=steps) @@ -1232,7 +1280,7 @@ def evaluate_any( else: return {} - def evaluate(self, batch_size: int = 1, max_steps: int = np.inf, weighted: bool = True): + def evaluate(self, batch_size: int = 64, max_steps: int = np.inf, weighted: bool = False): """ Evaluate the custom model on local data. diff --git a/sfaira/interface/__init__.py b/sfaira/interface/__init__.py index 51dee4b72..7c96cac34 100644 --- a/sfaira/interface/__init__.py +++ b/sfaira/interface/__init__.py @@ -1,2 +1,2 @@ -from sfaira.interface.model_zoo import ModelZoo, ModelZooEmbedding, ModelZooCelltype +from sfaira.interface.model_zoo import ModelZoo from sfaira.interface.user_interface import UserInterface diff --git a/sfaira/interface/model_zoo.py b/sfaira/interface/model_zoo.py index 1267b153f..89d3a69a6 100644 --- a/sfaira/interface/model_zoo.py +++ b/sfaira/interface/model_zoo.py @@ -1,15 +1,11 @@ import abc -try: - import kipoi -except ImportError: - kipoi = None import numpy as np import pandas as pd from typing import List, Union from sfaira.versions.metadata import CelltypeUniverse from sfaira.consts import OntologyContainerSfaira -from sfaira.versions.topologies import TopologyContainer +from sfaira.versions.topologies import TopologyContainer, TOPOLOGIES class ModelZoo(abc.ABC): @@ -18,39 +14,95 @@ class ModelZoo(abc.ABC): """ topology_container: TopologyContainer ontology: dict - model_id: Union[str, None] - model_class: Union[str, None] - model_class: Union[str, None] - model_type: Union[str, None] - model_topology: Union[str, None] - model_version: Union[str, None] + _model_id: Union[str, None] celltypes: Union[CelltypeUniverse, None] def __init__( self, - model_lookuptable: Union[None, pd.DataFrame] = None + model_lookuptable: Union[None, pd.DataFrame] = None, + model_class: Union[str, None] = None, ): """ :param model_lookuptable: model_lookuptable. + :param model_class: Model class to subset to. 
""" self._ontology_container_sfaira = OntologyContainerSfaira() if model_lookuptable is not None: # check if models in repository - self.ontology = self.load_ontology_from_model_ids(model_lookuptable['model_id'].values) - self.model_id = None - self.model_class = None - self.model_type = None - self.organisation = None - self.model_topology = None - self.model_version = None - self.topology_container = None + self.ontology = self.load_ontology_from_model_ids(model_ids=model_lookuptable['model_id'].values, + model_class=model_class) + self._model_id = None self.celltypes = None - @abc.abstractmethod + @property + def model_class(self): + assert self.model_id is not None, "set model_id first" + return self.model_id.split('_')[0] + + @property + def model_name(self): + assert self.model_id is not None, "set model_id first" + return self.model_id.split('_')[1] + + @property + def model_organism(self): + # TODO: this is a custom name ontology + assert self.model_id is not None, "set model_id first" + return self.model_id.split('_')[1].split("-")[0] + + @property + def model_organ(self): + # TODO: this is a custom name ontology + assert self.model_id is not None, "set model_id first" + return self.model_id.split('_')[1].split("-")[1] + + @property + def model_type(self): + # TODO: this is a custom name ontology + assert self.model_id is not None, "set model_id first" + return self.model_id.split('_')[1].split("-")[2] + + @property + def model_topology(self): + # TODO: this is a custom name ontology + assert self.model_id is not None, "set model_id first" + return self.model_id.split('_')[1].split("-")[3] + + @property + def model_version(self): + # TODO: this is a custom name ontology + assert self.model_id is not None, "set model_id first" + return self.model_id.split('_')[1].split("-")[4] + + @property + def organisation(self): + assert self.model_id is not None, "set model_id first" + return self.model_id.split('_')[2] + def load_ontology_from_model_ids( self, - model_ids - ): - pass + model_ids, + model_class: Union[str, None] = None, + ) -> dict: + """ + Load model ontology based on models available in model lookup tables. + + :param model_ids: Table listing all available model_ids. + :param model_class: Model class to subset to. + :return: Dictionary formatted ontology. + """ + + ids = [x for x in model_ids if (x.split('_')[0] == model_class or model_class is None)] + id_df = pd.DataFrame( + [i.split('_')[1:6] for i in ids], + columns=['name', 'organisation'] + ) + model = np.unique(id_df['name']) + ontology = dict.fromkeys(model) + for m in model: + id_df_m = id_df[id_df.model_type == m] + orga = np.unique(id_df_m['organisation']) + ontology[m] = dict.fromkeys(orga) + return ontology def _order_versions( self, @@ -66,25 +118,19 @@ def _order_versions( return versions - def set_model_id( - self, - model_id: str - ): + @property + def model_id(self): + return self._model_id + + @model_id.setter + def model_id(self, x: str): """ Set model ID to a manually supplied ID. - :param model_id: Model ID to set. Format: pipeline_genome_organ_model_organisation_topology_version + :param x: Model ID to set. 
Format: pipeline_genome_organ_model_organisation_topology_version """ - if len(model_id.split('_')) < 6: - raise RuntimeError(f'Model ID {model_id} is invalid!') - self.model_id = model_id - ixs = self.model_id.split('_') - self.model_class = ixs[0] - self.model_id = ixs[1] - self.model_type = ixs[2] - self.organisation = ixs[3] - self.model_topology = ixs[4] - self.model_version = ixs[5] + assert len(x.split('_')) == 3, f'model_id {x} is invalid' + self._model_id = x def save_weights_to_remote(self, path=None): """ @@ -113,14 +159,6 @@ def call_kipoi(self): """ raise NotImplementedError() - def models(self) -> List[str]: - """ - Return list of available models. - - :return: List of models available. - """ - return self.ontology.keys() - def topology( self, model_type: str, @@ -164,171 +202,11 @@ def model_hyperparameters(self) -> dict: assert self.topology_container is not None return self.topology_container.topology["hyper_parameters"] - -class ModelZooEmbedding(ModelZoo): - - """ - The supported model ontology is: - - organism -> organ -> model -> organisation -> topology -> version -> ID - - Maybe: include experimental protocol? Ie droplet, full-length, single-nuclei. - """ - - def load_ontology_from_model_ids( - self, - model_ids - ) -> dict: - """ - Load model ontology based on models available in model lookup tables. - - :param model_ids: Table listing all available model_ids. - :return: Dictionary formatted ontology. - """ - - ids = [i for i in model_ids if i.split('_')[0] == 'embedding'] - id_df = pd.DataFrame( - [i.split('_')[1:6] for i in ids], - columns=['id', 'model_type', 'organisation', 'model_topology', 'model_version'] - ) - model = np.unique(id_df['model_type']) - ontology = dict.fromkeys(model) - for m in model: - id_df_m = id_df[id_df.model_type == m] - orga = np.unique(id_df_m['organisation']) - ontology[m] = dict.fromkeys(orga) - for org in orga: - id_df_org = id_df_m[id_df_m.organisation == org] - topo = np.unique(id_df_org['model_topology']) - ontology[m][org] = dict.fromkeys(topo) - for t in topo: - id_df_t = id_df_org[id_df_org.model_topology == t] - ontology[m][org][t] = id_df_t.model_version.tolist() - - return ontology - - def set_latest( - self, - model_type: str, - organisation: str, - model_topology: str - ): - """ - Set model ID to latest model in given ontology group. - - :param model_type: Identifier of model_type to select. - :param organisation: Identifier of organisation to select. 
- :param model_topology: Identifier of model_topology to select - :return: - """ - assert model_type in self.ontology.keys(), "model_type requested was not found in ontology" - assert organisation in self.ontology[model_type].keys(), \ - "organisation requested was not found in ontology" - assert model_topology in self.ontology[model_type][organisation].keys(), \ - "model_topology requested was not found in ontology" - - versions = self.versions( - model_type=model_type, - organisation=organisation, - model_topology=model_topology - ) - self.model_type = model_type - self.organisation = organisation - self.model_topology = model_topology # set to model for now, could be organism/organ specific later - - self.model_version = self._order_versions(versions=versions)[0] - self.model_id = '_'.join([ - 'embedding', - self.id, - self.model_type, - self.organisation, - self.model_topology, - self.model_version - ]) - - -class ModelZooCelltype(ModelZoo): - """ - The supported model ontology is: - - organism -> organ -> model -> organisation -> topology -> version -> ID - - Maybe: include experimental protocol? Ie droplet, full-length, single-nuclei. - - Note on topology id: The topology ID is x.y.z, x is the major cell type version and y.z is the cell type model - topology. Cell type model ontologies do not include the output size as this is set by the cell type version. - """ - - def load_ontology_from_model_ids( - self, - model_ids - ) -> dict: - """ - Load model ontology based on models available in model lookup tables. - - :param model_ids: Table listing all available model_ids. - :return: Dictionary formatted ontology. - """ - - ids = [i for i in model_ids if i.split('_')[0] == 'celltype'] - id_df = pd.DataFrame( - [i.split('_')[1:6] for i in ids], - columns=['id', 'model_type', 'organisation', 'model_topology', 'model_version'] - ) - model = np.unique(id_df['model_type']) - ontology = dict.fromkeys(model) - for m in model: - id_df_m = id_df[id_df.model_type == m] - orga = np.unique(id_df_m['organisation']) - ontology[m] = dict.fromkeys(orga) - for org in orga: - id_df_org = id_df_m[id_df_m.organisation == org] - topo = np.unique(id_df_org['model_topology']) - ontology[m][org] = dict.fromkeys(topo) - for t in topo: - id_df_t = id_df_org[id_df_org.model_topology == t] - ontology[m][org][t] = id_df_t.model_version.tolist() - - return ontology - - def set_latest( - self, - model_type: str, - organisation: str, - model_topology: str - ): - """ - Set model ID to latest model in given ontology group. - - :param organism: Identifier of organism to select. - :param organ: Identifier of organ to select. - :param model_type: Identifier of model_type to select. - :param organisation: Identifier of organisation to select. - :param model_topology: Identifier of model_topology to select - :return: - """ - assert model_type in self.ontology.keys(), "model_type requested was not found in ontology" - assert organisation in self.ontology[model_type].keys(), \ - "organisation requested was not found in ontology" - assert model_topology in self.ontology[model_type][organisation].keys(), \ - "model_topology requested was not found in ontology" - - versions = self.versions( - model_type=model_type, - organisation=organisation, - model_topology=model_topology + @property + def topology_container(self) -> TopologyContainer: + # TODO: this ID decomposition to organism is custom to the topologies handled in this package. 
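Under the new scheme a model_id has exactly three underscore-separated parts and the middle part is dash-separated; the id below is made up purely for illustration of the parsing properties:

from sfaira.interface import ModelZoo

zoo = ModelZoo()  # as instantiated in TrainModel
zoo.model_id = "embedding_mouse-lung-linear-0.1-0.1_theislab"
zoo.model_class     # "embedding"
zoo.model_name      # "mouse-lung-linear-0.1-0.1"
zoo.model_organism  # "mouse"
zoo.model_organ     # "lung"
zoo.model_type      # "linear"
zoo.model_topology  # "0.1"
zoo.model_version   # "0.1"
zoo.organisation    # "theislab"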
+ organism = self.model_name.split("-")[0] + return TopologyContainer( + topology=TOPOLOGIES[organism][self.model_class][self.model_type][self.model_topology], + topology_id=self.model_version ) - - self.model_type = model_type - self.organisation = organisation - self.model_topology = model_topology # set to model for now, could be organism/organ specific later - - self.model_version = self._order_versions(versions=versions)[0] - self.model_id = '_'.join([ - 'celltype', - self.id, - self.model_type, - self.organisation, - self.model_topology, - self.model_version - ]) diff --git a/sfaira/interface/user_interface.py b/sfaira/interface/user_interface.py index c6f88b018..db4529678 100644 --- a/sfaira/interface/user_interface.py +++ b/sfaira/interface/user_interface.py @@ -11,7 +11,7 @@ from sfaira.data import DatasetInteractive from sfaira.estimators import EstimatorKerasEmbedding, EstimatorKerasCelltype -from sfaira.interface.model_zoo import ModelZooEmbedding, ModelZooCelltype +from sfaira.interface.model_zoo import ModelZoo class UserInterface: @@ -43,8 +43,8 @@ class UserInterface: estimator_celltype: Union[EstimatorKerasCelltype, None] model_kipoi_embedding: Union[None] model_kipoi_celltype: Union[BaseModel, None] - zoo_embedding: Union[ModelZooEmbedding, None] - zoo_celltype: Union[ModelZooCelltype, None] + zoo_embedding: Union[ModelZoo, None] + zoo_celltype: Union[ModelZoo, None] data: Union[anndata.AnnData] model_lookuptable: Union[pd.DataFrame, None] @@ -87,8 +87,8 @@ def __init__( # TODO: workaround to deal with model ids bearing file endings in model lookuptable (as is the case in first sfaira model repo upload) self.model_lookuptable['model_id'] = [i.replace('.h5', '').replace('.data-00000-of-00001', '') for i in self.model_lookuptable['model_id']] - self.zoo_embedding = ModelZooEmbedding(self.model_lookuptable) - self.zoo_celltype = ModelZooCelltype(self.model_lookuptable) + self.zoo_embedding = ModelZoo(model_lookuptable=self.model_lookuptable, model_class="embedding") + self.zoo_celltype = ModelZoo(model_lookuptable=self.model_lookuptable, model_class="celltype") def _load_lookuptable( self, diff --git a/sfaira/train/train_model.py b/sfaira/train/train_model.py index 4462e99b1..b51ff99e1 100644 --- a/sfaira/train/train_model.py +++ b/sfaira/train/train_model.py @@ -5,9 +5,10 @@ import pickle from typing import Union +from sfaira.consts import AdataIdsSfaira from sfaira.data import DistributedStore, Universe from sfaira.estimators import EstimatorKeras, EstimatorKerasCelltype, EstimatorKerasEmbedding -from sfaira.interface import ModelZooEmbedding, ModelZooCelltype +from sfaira.interface import ModelZoo class TrainModel: @@ -33,16 +34,26 @@ def __init__( self.data = data else: raise ValueError(f"did not recongize data of type {type(data)}") + self.zoo = ModelZoo() + + def load_into_memory(self): + """ + Loads backed objects from DistributedStore into single adata object in memory in .data slot. 
+ :return: + """ + if isinstance(self.data, DistributedStore): + self.data = self.data.adata @abc.abstractmethod def init_estim(self): pass @abc.abstractmethod - def _save_specific( - self, - fn: str - ): + def save_eval(self, fn: str): + pass + + @abc.abstractmethod + def _save_specific(self, fn: str): pass def save( @@ -78,7 +89,6 @@ def __init__( data: Union[str, anndata.AnnData, Universe, DistributedStore], ): super(TrainModelEmbedding, self).__init__(data=data) - self.zoo = ModelZooEmbedding(model_lookuptable=None) self.estimator = None self.model_dir = model_path @@ -91,14 +101,11 @@ def init_estim( data=self.data, model_dir=self.model_dir, model_id=self.zoo.model_id, - model_topology=self.zoo.model_topology + model_topology=self.zoo.topology_container ) self.estimator.init_model(override_hyperpar=override_hyperpar) - def save_eval( - self, - fn: str - ): + def save_eval(self, fn: str): evaluation_train = self.estimator.evaluate_any(idx=self.estimator.idx_train) evaluation_val = self.estimator.evaluate_any(idx=self.estimator.idx_eval) evaluation_test = self.estimator.evaluate_any(idx=self.estimator.idx_test) @@ -112,10 +119,7 @@ def save_eval( with open(fn + '_evaluation.pickle', 'wb') as f: pickle.dump(obj=evaluation, file=f) - def _save_specific( - self, - fn: str - ): + def _save_specific(self, fn: str): """ Save embedding prediction: @@ -123,10 +127,7 @@ def _save_specific( :return: """ embedding = self.estimator.predict_embedding() - df_summary = self.estimator.obs_test[ - ["dataset", "cell_ontology_class", "state_exact", "author", "year", "assay_sc", - "assay_differentiation", "assay_type_differentiation", "cell_line", "sample_source"] - ] + df_summary = self.estimator.obs_test[AdataIdsSfaira.obs_keys] df_summary["ncounts"] = np.asarray( self.estimator.data.X[np.sort(self.estimator.idx_test), :].sum(axis=1)[np.argsort(self.estimator.idx_test)] ).flatten() @@ -145,7 +146,6 @@ def __init__( fn_target_universe: str, ): super(TrainModelCelltype, self).__init__(data=data) - self.zoo = ModelZooCelltype(model_lookuptable=None) self.estimator = None self.model_dir = model_path self.data.celltypes_universe.load_target_universe(fn=fn_target_universe) @@ -159,14 +159,11 @@ def init_estim( data=self.data, model_dir=self.model_dir, model_id=self.zoo.model_id, - model_topology=self.zoo.model_topology + model_topology=self.zoo.topology_container ) self.estimator.init_model(override_hyperpar=override_hyperpar) - def save_eval( - self, - fn: str - ): + def save_eval(self, fn: str): evaluation = { 'train': self.estimator.evaluate_any(idx=self.estimator.idx_train, weighted=False), 'val': self.estimator.evaluate_any(idx=self.estimator.idx_eval, weighted=False), @@ -184,10 +181,7 @@ def save_eval( with open(fn + '_evaluation_weighted.pickle', 'wb') as f: pickle.dump(obj=evaluation_weighted, file=f) - def _save_specific( - self, - fn: str - ): + def _save_specific(self, fn: str): """ Save true and predicted labels on test set: @@ -196,10 +190,7 @@ def _save_specific( """ ytrue = self.estimator.ytrue() yhat = self.estimator.predict() - df_summary = self.estimator.obs_test[ - ["dataset", "cell_ontology_class", "state_exact", "author", "year", "assay_sc", - "assay_differentiation", "assay_type_differentiation", "cell_line", "sample_source"] - ] + df_summary = self.estimator.obs_test[AdataIdsSfaira.obs_keys] df_summary["ncounts"] = np.asarray(self.estimator.data.X[self.estimator.idx_test, :].sum(axis=1)).flatten() np.save(file=fn + "_ytrue", arr=ytrue) np.save(file=fn + "_yhat", arr=yhat) @@ -207,16 
+198,16 @@ def _save_specific( with open(fn + '_ontology_names.pickle', 'wb') as f: pickle.dump(obj=self.estimator.ids, file=f) - cell_counts = self.data.obs_concat(keys=['cell_ontology_class'])['cell_ontology_class'].value_counts().to_dict() + cell_counts = self.data.obs['cell_ontology_class'].value_counts().to_dict() cell_counts_leaf = cell_counts.copy() for k in cell_counts.keys(): if k not in self.estimator.ids: - if k not in self.estimator.celltype_universe.ontology.node_ids: + if k not in self.estimator.celltype_universe.onto_cl.node_ids: raise(ValueError(f"Celltype '{k}' not found in celltype universe")) - for leaf in self.estimator.celltype_universe.ontology.node_ids: + for leaf in self.estimator.celltype_universe.onto_cl.node_ids: if leaf not in cell_counts_leaf.keys(): cell_counts_leaf[leaf] = 0 - cell_counts_leaf[leaf] += 1 / len(self.estimator.celltype_universe.ontology.node_ids) + cell_counts_leaf[leaf] += 1 / len(self.estimator.celltype_universe.onto_cl.node_ids) del cell_counts_leaf[k] with open(fn + '_celltypes_valuecounts_wholedata.pickle', 'wb') as f: pickle.dump(obj=[cell_counts, cell_counts_leaf], file=f) diff --git a/sfaira/unit_tests/data/test_dataset.py b/sfaira/unit_tests/data/test_dataset.py index deab6f1bc..940b22f0a 100644 --- a/sfaira/unit_tests/data/test_dataset.py +++ b/sfaira/unit_tests/data/test_dataset.py @@ -1,15 +1,14 @@ import numpy as np import os import pytest -import scipy.sparse from sfaira.data import DatasetSuperGroup from sfaira.data import Universe MOUSE_GENOME_ANNOTATION = "Mus_musculus.GRCm38.102" -dir_data = "../test_data" -dir_meta = "../test_data/meta" +dir_data = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data") +dir_meta = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data/meta") def test_dsgs_instantiate(): diff --git a/sfaira/unit_tests/data/test_store.py b/sfaira/unit_tests/data/test_store.py index 2d50aa9e6..b89a29d22 100644 --- a/sfaira/unit_tests/data/test_store.py +++ b/sfaira/unit_tests/data/test_store.py @@ -1,14 +1,18 @@ import numpy as np import os import pytest +import time +from typing import List from sfaira.data import DistributedStore -from sfaira.data import Universe +from sfaira.versions.genomes import GenomeContainer +from sfaira.unit_tests.utils import cached_store_writing + MOUSE_GENOME_ANNOTATION = "Mus_musculus.GRCm38.102" -dir_data = "../test_data" -dir_meta = "../test_data/meta" +dir_data = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data") +dir_meta = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data/meta") """ @@ -16,24 +20,14 @@ """ -def test_store_config(): +def test_config(): """ Test that data set config files can be set, written and recovered. 
""" - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) - ds.subset(key="organism", values=["mouse"]) - ds.subset(key="organ", values=["lung"]) - ds.load() - ds.streamline_features(remove_gene_version=True, match_to_reference={"mouse": MOUSE_GENOME_ANNOTATION}, - subset_genes_to_type="protein_coding") - ds.streamline_metadata(schema="sfaira", uns_to_obs=False, clean_obs=True, clean_var=True, clean_uns=True, - clean_obs_names=True) - store_path = os.path.join(dir_data, "store") + store_path = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=MOUSE_GENOME_ANNOTATION) config_path = os.path.join(store_path, "lung") - ds.write_distributed_store(dir_cache=store_path, store="h5ad", dense=True) store = DistributedStore(cache_path=store_path) store.subset(attr_key="assay_sc", values=["10x sequencing"]) - store.subset_cells(attr_key="assay_sc", values=["10x sequencing"]) store.write_config(fn=config_path) store2 = DistributedStore(cache_path=store_path) store2.load_config(fn=config_path) @@ -41,21 +35,12 @@ def test_store_config(): assert np.all([np.all(store.indices[k] == store2.indices[k]) for k in store.indices.keys()]) -def test_store_type_targets(): +def test_type_targets(): """ Test that target leave nodes can be set, written and recovered. """ - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) - ds.subset(key="organism", values=["mouse"]) - ds.subset(key="organ", values=["lung"]) - ds.load() - ds.streamline_features(remove_gene_version=True, match_to_reference={"mouse": MOUSE_GENOME_ANNOTATION}, - subset_genes_to_type="protein_coding") - ds.streamline_metadata(schema="sfaira", uns_to_obs=False, clean_obs=True, clean_var=True, clean_uns=True, - clean_obs_names=True) - store_path = os.path.join(dir_data, "store") + store_path = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=MOUSE_GENOME_ANNOTATION) target_path = os.path.join(store_path, "lung") - ds.write_distributed_store(dir_cache=store_path, store="h5ad", dense=True) store = DistributedStore(cache_path=store_path) observed_nodes = np.unique(np.concatenate([ x.obs[store._adata_ids_sfaira.cell_ontology_class] @@ -72,3 +57,48 @@ def test_store_type_targets(): assert len(leaves_all) > len(leaves1) assert len(set(leaves1).union(set(leaves2))) == len(leaves1) assert np.all([x in leaves1 for x in leaves2]) + + +@pytest.mark.parametrize("idx", [None, np.concatenate([np.arange(150, 200), np.array([1, 100, 2003, 33])])]) +@pytest.mark.parametrize("batch_size", [1, 10]) +@pytest.mark.parametrize("obs_keys", [[], ["cell_ontology_class"]]) +@pytest.mark.parametrize("gc", [(None, {}), (MOUSE_GENOME_ANNOTATION, {"biotype": "protein_coding"})]) +def test_generator_shapes(idx, batch_size: int, obs_keys: List[str], gc: tuple): + """ + Test generators queries do not throw errors and that output shapes are correct. + """ + assembly, subset = gc + store_path = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=MOUSE_GENOME_ANNOTATION) + store = DistributedStore(cache_path=store_path) + if assembly is not None: + gc = GenomeContainer(assembly=assembly) + gc.subset(**subset) + store.genome_container = gc + g = store.generator( + idx=idx, + batch_size=batch_size, + obs_keys=obs_keys, + ) + nobs = len(idx) if idx is not None else store.n_obs + batch_sizes = [] + t0 = time.time() + for i, z in enumerate(g()): + x_i, obs_i = z + assert x_i.shape[0] == obs_i.shape[0] + if i == 0: # First batch hast correct shape, last batch not necessarily! 
+ x = x_i + obs = obs_i + batch_sizes.append(x_i.shape[0]) + tdelta = time.time() - t0 + print(f"time for iterating over generator:" + f" {tdelta}s for {np.sum(batch_sizes)} cells in {len(batch_sizes)} batches," + f" {tdelta / len(batch_sizes)}s per batch.") + # Only the last batch in each data set can be of different size: + assert np.sum(batch_sizes != batch_size) <= len(store.adatas.keys()) + assert x.shape[0] == batch_size, (x.shape, batch_size) + assert obs.shape[0] == batch_size, (obs.shape, batch_size) + assert x.shape[1] == store.n_vars, (x.shape, store.n_vars) + assert obs.shape[1] == len(obs_keys), (x.shape, obs_keys) + assert np.sum(batch_sizes) == nobs, (x.shape, obs_keys) + if assembly is not None: + assert x.shape[1] == gc.n_var, (x.shape, gc.n_var) diff --git a/sfaira/unit_tests/estimators/test_estimator.py b/sfaira/unit_tests/estimators/test_estimator.py index 2fe831ec4..521efd9d0 100644 --- a/sfaira/unit_tests/estimators/test_estimator.py +++ b/sfaira/unit_tests/estimators/test_estimator.py @@ -1,18 +1,32 @@ import abc import anndata import numpy as np +import os +import pandas as pd +import pytest +import time from typing import Union +from sfaira.data import DistributedStore from sfaira.estimators import EstimatorKeras, EstimatorKerasCelltype, EstimatorKerasEmbedding from sfaira.versions.topologies import TopologyContainer +from sfaira.unit_tests.utils import cached_store_writing, simulate_anndata +dir_data = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data") +dir_meta = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data/meta") +cache_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), + "cache", "genomes") + +ASSEMBLY = "Mus_musculus.GRCm38.102" GENES = ["ENSMUSG00000000003", "ENSMUSG00000000028"] TARGETS = ["T cell", "stromal cell"] +ASSAYS = ["10x sequencing", "Smart-seq2"] + TOPOLOGY_EMBEDDING_MODEL = { "model_type": None, "input": { - "genome": "Mus_musculus.GRCm38.102", + "genome": ASSEMBLY, "genes": ["ensg", GENES], }, "output": {}, @@ -27,7 +41,7 @@ TOPOLOGY_CELLTYPE_MODEL = { "model_type": None, "input": { - "genome": "Mus_musculus.GRCm38.102", + "genome": ASSEMBLY, "genes": ["ensg", GENES], }, "output": { @@ -43,8 +57,11 @@ class HelperEstimatorBase: + + data: Union[anndata.AnnData, DistributedStore] estimator: Union[EstimatorKeras] - data: Union[anndata.AnnData] + model_type: str + tc: TopologyContainer """ Contains functions _test* to test individual functions and attributes of estimator class. @@ -53,74 +70,102 @@ class HelperEstimatorBase: basic_estimator_test(). See _test_call() for an example. """ - def simulate(self): + def _simulate(self) -> anndata.AnnData: """ Simulate basic data example used for unit test. - Sets attribute .data with simulated data. + :return: Simulated data set. + """ + return simulate_anndata(n_obs=100, assays=ASSAYS, genes=self.tc.gc.ensembl, targets=TARGETS) - :return: + def load_adata(self): """ - nobs = 100 - self.data = anndata.AnnData( - np.random.randint(low=0, high=100, size=(nobs, len(GENES))).astype(np.float32) - ) - self.data.obs["cell_ontology_class"] = [ - TARGETS[np.random.randint(0, len(TARGETS))] - for i in range(nobs) - ] - self.data.var["ensembl"] = GENES + Sets attribute .data with simulated data. 
+ """ + self.data = self._simulate() + + def load_store(self): + store_path = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=ASSEMBLY) + store = DistributedStore(cache_path=store_path) + self.data = store + + @abc.abstractmethod + def init_topology(self, model_type: str, feature_space: str): + pass @abc.abstractmethod - def init_estimator(self, model_type: str): + def init_estimator(self): """ Initialise target estimator as .estimator attribute. """ pass + def estimator_train(self, test_split): + self.estimator.init_model() + self.estimator.train( + optimizer="adam", + lr=0.005, + epochs=2, + batch_size=4, + validation_split=0.5, + test_split=test_split, + validation_batch_size=4, + max_validation_steps=1, + shuffle_buffer_size=10, + cache_full=False, + ) + @abc.abstractmethod - def basic_estimator_test(self): + def basic_estimator_test(self, test_split): pass - def fatal_estimator_test(self, model_type): + def load_estimator(self, model_type, data_type, feature_space, test_split): + self.init_topology(model_type=model_type, feature_space=feature_space) np.random.seed(1) - self.simulate() - self.init_estimator(model_type=model_type) + if data_type == "adata": + self.load_adata() + else: + self.load_store() + self.init_estimator() + self.estimator_train(test_split=test_split) + + def fatal_estimator_test(self, model_type, data_type, test_split=0.1, feature_space="small"): + self.load_estimator(model_type=model_type, data_type=data_type, feature_space=feature_space, + test_split=test_split) self.basic_estimator_test() - return True class HelperEstimatorKerasEmbedding(HelperEstimatorBase): estimator: EstimatorKerasEmbedding + model_type: str + tc: TopologyContainer - def init_estimator(self, model_type): + def init_topology(self, model_type: str, feature_space: str): topology = TOPOLOGY_EMBEDDING_MODEL.copy() + if feature_space == "full": + # Read 500 genes (not full protein coding) to compromise between being able to distinguish observations + # and reducing run time of unit tests. 
+ tab = pd.read_csv(os.path.join(cache_dir, ASSEMBLY + ".csv")) + genes_full = tab.loc[tab["gene_biotype"].values == "protein_coding", "gene_id"].values[:500].tolist() + topology["input"]["genes"] = ["ensg", genes_full] topology["model_type"] = model_type if model_type == "linear": topology["hyper_parameters"]["latent_dim"] = 2 else: - topology["hyper_parameters"]["latent_dim"] = (len(GENES), 2, len(GENES)) + topology["hyper_parameters"]["latent_dim"] = (2, 2, 2) self.model_type = model_type + self.tc = TopologyContainer(topology=topology, topology_id="0.1") + + def init_estimator(self): self.estimator = EstimatorKerasEmbedding( data=self.data, model_dir=None, model_id="testid", - model_topology=TopologyContainer(topology=topology, topology_id="0.1") + model_topology=self.tc ) - def basic_estimator_test(self): - self.estimator.init_model() - self.estimator.train( - optimizer="adam", - lr=0.005, - epochs=2, - batch_size=32, - validation_split=0.1, - test_split=0.1, - validation_batch_size=32, - max_validation_steps=1 - ) + def basic_estimator_test(self, test_split=0.1): _ = self.estimator.evaluate() prediction_output = self.estimator.predict() prediction_embed = self.estimator.predict_embedding() @@ -131,41 +176,37 @@ def basic_estimator_test(self): new_prediction_embed = self.estimator.predict_embedding() new_weights = self.estimator.model.training_model.get_weights() for i in range(len(weights)): - assert np.allclose(weights[i], new_weights[i], rtol=1e-6, atol=1e-6) + if not np.any(np.isnan(weights[i])): + assert np.allclose(weights[i], new_weights[i], rtol=1e-6, atol=1e-6) if self.model_type != 'vae': - assert np.allclose(prediction_output, new_prediction_output, rtol=1e-6, atol=1e-6) - assert np.allclose(prediction_embed, new_prediction_embed, rtol=1e-6, atol=1e-6) + if not np.any(np.isnan(prediction_output)): + assert np.allclose(prediction_output, new_prediction_output, rtol=1e-6, atol=1e-6) + assert np.allclose(prediction_embed, new_prediction_embed, rtol=1e-6, atol=1e-6) class HelperEstimatorKerasCelltype(HelperEstimatorBase): estimator: EstimatorKerasCelltype + model_type: str + tc: TopologyContainer - def init_estimator(self, model_type: str): + def init_topology(self, model_type: str, feature_space: str): topology = TOPOLOGY_CELLTYPE_MODEL.copy() topology["model_type"] = model_type - topology["hyper_parameters"]["latent_dim"] = (len(GENES), 2) + topology["hyper_parameters"]["latent_dim"] = (2,) self.model_type = model_type + self.tc = TopologyContainer(topology=topology, topology_id="0.1") + + def init_estimator(self): self.estimator = EstimatorKerasCelltype( data=self.data, model_dir=None, model_id="testid", - model_topology=TopologyContainer(topology=topology, topology_id="0.1"), + model_topology=self.tc ) self.estimator.celltype_universe.leaves = TARGETS - def basic_estimator_test(self): - self.estimator.init_model() - self.estimator.train( - optimizer="adam", - lr=0.005, - epochs=2, - batch_size=32, - validation_split=0.1, - test_split=0.1, - validation_batch_size=32, - max_validation_steps=1 - ) + def basic_estimator_test(self, test_split=0.1): _ = self.estimator.evaluate() prediction_output = self.estimator.predict() weights = self.estimator.model.training_model.get_weights() @@ -175,36 +216,165 @@ def basic_estimator_test(self): new_weights = self.estimator.model.training_model.get_weights() print(self.estimator.model.training_model.summary()) for i in range(len(weights)): - assert np.allclose(weights[i], new_weights[i], rtol=1e-6, atol=1e-6) - assert 
np.allclose(prediction_output, new_prediction_output, rtol=1e-6, atol=1e-6) + if not np.any(np.isnan(weights[i])): + assert np.allclose(weights[i], new_weights[i], rtol=1e-6, atol=1e-6) + if not np.any(np.isnan(prediction_output)): + assert np.allclose(prediction_output, new_prediction_output, rtol=1e-6, atol=1e-6) # Test embedding models: -def test_for_fatal_linear(): +@pytest.mark.parametrize("data_type", ["adata", "store"]) +def test_for_fatal_linear(data_type): test_estim = HelperEstimatorKerasEmbedding() - test_estim.fatal_estimator_test(model_type="linear") + test_estim.fatal_estimator_test(model_type="linear", data_type=data_type) def test_for_fatal_ae(): test_estim = HelperEstimatorKerasEmbedding() - test_estim.fatal_estimator_test(model_type="ae") + test_estim.fatal_estimator_test(model_type="ae", data_type="adata") def test_for_fatal_vae(): test_estim = HelperEstimatorKerasEmbedding() - test_estim.fatal_estimator_test(model_type="vae") + test_estim.fatal_estimator_test(model_type="vae", data_type="adata") # Test cell type predictor models: -def test_for_fatal_mlp(): +@pytest.mark.parametrize("data_type", ["adata", "store"]) +def test_for_fatal_mlp(data_type): test_estim = HelperEstimatorKerasCelltype() - test_estim.fatal_estimator_test(model_type="mlp") + test_estim.fatal_estimator_test(model_type="mlp", data_type=data_type) def test_for_fatal_marker(): test_estim = HelperEstimatorKerasCelltype() - test_estim.fatal_estimator_test(model_type="marker") + test_estim.fatal_estimator_test(model_type="marker", data_type="adata") + + +# Test index sets + + +@pytest.mark.parametrize("data_type", ["adata", "store"]) +@pytest.mark.parametrize("test_split", [0.3, {"assay_sc": "10x sequencing"}]) +def test_split_index_sets(data_type: str, test_split): + """ + Test that train, val, test split index sets are correct: + + 1) complete + 2) non-overlapping + 3) that test indices map to all (float split) or distinct (attribute split) data sets + 4) do not contain duplicated observations within and across splits (defined based on the feature vectors) + """ + test_estim = HelperEstimatorKerasEmbedding() + # Need full feature space here because observations are not necessarily different in small model testing feature + # space with only two genes: + t0 = time.time() + test_estim.load_estimator(model_type="linear", data_type=data_type, test_split=test_split, feature_space="full") + print(f"time for running estimator test: {time.time() - t0}s") + idx_train = test_estim.estimator.idx_train + idx_eval = test_estim.estimator.idx_eval + idx_test = test_estim.estimator.idx_test + # 1) Assert that index assignments sum up to full data set: + assert len(idx_train) + len(idx_eval) + len(idx_test) == test_estim.data.n_obs, \ + (len(idx_train), len(idx_eval), len(idx_test), test_estim.data.n_obs) + # 2) Assert that index assignments are exclusive to each split: + assert len(set(idx_train).intersection(set(idx_eval))) == 0 + assert len(set(idx_train).intersection(set(idx_test))) == 0 + assert len(set(idx_test).intersection(set(idx_eval))) == 0 + # 3) Check partition of index vectors over store data sets matches test split scenario: + if isinstance(test_estim.estimator.data, DistributedStore): + # Prepare data set-wise index vectors that are numbered in the same way as global split index vectors. 
+ # See also EstimatorKeras.train and DistributedStore.subset_cells_idx_global +        idx_raw = test_estim.estimator.data.indices_global +        if isinstance(test_split, float): +            # Make sure that indices from each split are in each data set: +            for z in [idx_train, idx_eval, idx_test]: +                assert np.all([  # in each data set +                    np.any([y in z for y in x])  # at least one match of data set to split index set +                    for x in idx_raw +                ]) +        else: +            # Make sure that indices from (train, val) and test split are exclusive: +            datasets_train = np.where([  # in each data set +                np.any([y in idx_train for y in x])  # at least one match of data set to split index set +                for x in idx_raw +            ])[0] +            datasets_eval = np.where([  # in each data set +                np.any([y in idx_eval for y in x])  # at least one match of data set to split index set +                for x in idx_raw +            ])[0] +            datasets_test = np.where([  # in each data set +                np.any([y in idx_test for y in x])  # at least one match of data set to split index set +                for x in idx_raw +            ])[0] +            assert set(datasets_train) == set(datasets_eval), (datasets_train, datasets_eval) +            assert len(set(datasets_train).intersection(set(datasets_test))) == 0, (datasets_train, datasets_test) +    # 4) Assert that observations mapped to indices are actually unique based on expression vectors: +    # Build numpy arrays of expression input data sets from tensorflow data sets directly from estimator. +    # These data sets are the most processed transformation of the data and are passed directly to the model. +    t0 = time.time() +    ds_train = test_estim.estimator._get_dataset(idx=idx_train, batch_size=128, mode='eval', shuffle_buffer_size=1, +                                                 retrieval_batch_size=128) +    print(f"time for building training data set: {time.time() - t0}s") +    t0 = time.time() +    ds_eval = test_estim.estimator._get_dataset(idx=idx_eval, batch_size=128, mode='eval', shuffle_buffer_size=1, +                                                retrieval_batch_size=128) +    print(f"time for building validation data set: {time.time() - t0}s") +    t0 = time.time() +    ds_test = test_estim.estimator._get_dataset(idx=idx_test, batch_size=128, mode='eval', shuffle_buffer_size=1, +                                                retrieval_batch_size=128) +    print(f"time for building test data set: {time.time() - t0}s") +    x_train = [] +    x_eval = [] +    x_test = [] +    t0 = time.time() +    for x, y in ds_train.as_numpy_iterator(): +        x_train.append(x[0]) +    x_train = np.concatenate(x_train, axis=0) +    print(f"time for iterating over training data set: {time.time() - t0}s") +    t0 = time.time() +    for x, y in ds_eval.as_numpy_iterator(): +        x_eval.append(x[0]) +    x_eval = np.concatenate(x_eval, axis=0) +    print(f"time for iterating over validation data set: {time.time() - t0}s") +    t0 = time.time() +    for x, y in ds_test.as_numpy_iterator(): +        x_test.append(x[0]) +    x_test = np.concatenate(x_test, axis=0) +    print(f"time for iterating over test data set: {time.time() - t0}s") +    # Validate size of recovered numpy data sets: +    print(f"shapes received {(x_train.shape[0], x_eval.shape[0], x_test.shape[0])}") +    print(f"shapes expected {(len(idx_train), len(idx_eval), len(idx_test))}") +    assert x_train.shape[0] == len(idx_train) +    assert x_eval.shape[0] == len(idx_eval) +    assert x_test.shape[0] == len(idx_test) +    # Assert that observations are unique within partition: +    assert np.all([ +        np.sum([np.all(x_train[i] == x_train[j]) for j in range(x_train.shape[0])]) == 1 +        for i in range(x_train.shape[0]) +    ]) +    assert np.all([ +        np.sum([np.all(x_eval[i] == x_eval[j]) for j in range(x_eval.shape[0])]) == 1 +        for i in range(x_eval.shape[0]) +    ]) +    assert np.all([ +        np.sum([np.all(x_test[i] == 
x_test[j]) for j in range(x_test.shape[0])]) == 1 + for i in range(x_test.shape[0]) + ]) + # Assert that observations are not replicated across partitions: + assert not np.any([ + np.any([np.all(x_train[i] == x_eval[j]) for j in range(x_eval.shape[0])]) + for i in range(x_train.shape[0]) + ]) + assert not np.any([ + np.any([np.all(x_train[i] == x_test[j]) for j in range(x_test.shape[0])]) + for i in range(x_train.shape[0]) + ]) + assert not np.any([ + np.any([np.all(x_test[i] == x_eval[j]) for j in range(x_eval.shape[0])]) + for i in range(x_test.shape[0]) + ]) diff --git a/sfaira/unit_tests/interface/test_userinterface.py b/sfaira/unit_tests/interface/test_userinterface.py index 504e91984..613d016fe 100644 --- a/sfaira/unit_tests/interface/test_userinterface.py +++ b/sfaira/unit_tests/interface/test_userinterface.py @@ -1,12 +1,11 @@ import numpy as np import os from typing import Union -import unittest from sfaira.interface import UserInterface -class TestUi(unittest.TestCase): +class TestUi: ui: Union[UserInterface] data: np.ndarray @@ -27,7 +26,7 @@ def simulate(self): """ pass - def test_basic(self): + def _test_basic(self): """ Test all relevant model methods. @@ -47,7 +46,3 @@ def _test_kipoi(self): temp_fn = os.path.join(str(os.path.dirname(os.path.abspath(__file__))), '../test_data') self.ui = UserInterface(custom_repo=temp_fn, sfaira_repo=False) self.ui.compute_embedding_kipoi() - - -if __name__ == '__main__': - unittest.main() diff --git a/sfaira/unit_tests/interface/test_zoo.py b/sfaira/unit_tests/interface/test_zoo.py new file mode 100644 index 000000000..2a8f2bd30 --- /dev/null +++ b/sfaira/unit_tests/interface/test_zoo.py @@ -0,0 +1,31 @@ +import os +from sfaira.interface import ModelZoo + +dir_data = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data") +dir_meta = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data/meta") + + +def test_for_fatal_embedding(): + model_id = "embedding_human-lung-linear-0.1-0.1_mylab" + zoo = ModelZoo() + zoo.model_id = model_id + assert zoo.model_id == model_id + assert zoo.model_class == "embedding" + assert zoo.model_name == "human-lung-linear-0.1-0.1" + assert zoo.organisation == "mylab" + _ = zoo.topology_container + _ = zoo.topology_container.topology + _ = zoo.topology_container.gc + + +def test_for_fatal_celltype(): + model_id = "celltype_human-lung-mlp-0.0.1-0.1_mylab" + zoo = ModelZoo() + zoo.model_id = model_id + assert zoo.model_id == model_id + assert zoo.model_class == "celltype" + assert zoo.model_name == "human-lung-mlp-0.0.1-0.1" + assert zoo.organisation == "mylab" + _ = zoo.topology_container + _ = zoo.topology_container.topology + _ = zoo.topology_container.gc diff --git a/sfaira/unit_tests/trainer/__init__.py b/sfaira/unit_tests/trainer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sfaira/unit_tests/trainer/test_trainer.py b/sfaira/unit_tests/trainer/test_trainer.py new file mode 100644 index 000000000..c9886c5f0 --- /dev/null +++ b/sfaira/unit_tests/trainer/test_trainer.py @@ -0,0 +1,78 @@ +import anndata +import numpy as np +import os +import pytest +from typing import Union + +from sfaira.data import DistributedStore +from sfaira.interface import ModelZoo, ModelZooCelltype, ModelZooEmbedding +from sfaira.train import TrainModelCelltype, TrainModelEmbedding +from sfaira.unit_tests.utils import cached_store_writing, simulate_anndata + +dir_data = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data") +dir_meta = 
os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data/meta") + +ASSEMBLY = "Mus_musculus.GRCm38.102" +TARGETS = ["T cell", "stromal cell"] + + +class HelperTrainerBase: + +    data: Union[anndata.AnnData, DistributedStore] +    trainer: Union[TrainModelCelltype, TrainModelEmbedding] +    zoo: ModelZoo + +    def __init__(self, zoo: ModelZoo): +        self.model_id = zoo.model_id +        self.tc = zoo.topology_container + +    def _simulate(self) -> anndata.AnnData: +        """ +        Simulate basic data example used for unit test. + +        :return: Simulated data set. +        """ +        return simulate_anndata(n_obs=100, genes=self.tc.gc.ensembl, targets=TARGETS) + +    def load_adata(self): +        """ +        Sets attribute .data with simulated data. +        """ +        self.data = self._simulate() + +    def load_store(self): +        store_path = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=ASSEMBLY) +        store = DistributedStore(cache_path=store_path) +        self.data = store + +    def load_data(self, data_type): +        np.random.seed(1) +        if data_type == "adata": +            self.load_adata() +        else: +            self.load_store() + +    def test_for_fatal(self, cls): +        self.load_data(data_type="adata") +        trainer = cls( +            data=self.data, +            model_path=dir_meta, +        ) +        trainer.zoo.set_model_id(model_id=self.model_id) +        trainer.init_estim(override_hyperpar={}) + + +def test_for_fatal_embedding(): +    model_id = "embedding_human-lung_linear_mylab_0.1_0.1" +    zoo = ModelZooEmbedding() +    zoo.set_model_id(model_id=model_id) +    test_trainer = HelperTrainerBase(zoo=zoo) +    test_trainer.test_for_fatal(cls=TrainModelEmbedding) + + +def test_for_fatal(): +    model_id = "celltype_human-lung_mlp_mylab_0.0.1_0.1" +    zoo = ModelZooCelltype() +    zoo.set_model_id(model_id=model_id) +    test_trainer = HelperTrainerBase(zoo=zoo) +    test_trainer.test_for_fatal(cls=TrainModelCelltype) diff --git a/sfaira/unit_tests/utils.py b/sfaira/unit_tests/utils.py new file mode 100644 index 000000000..32d3d03fb --- /dev/null +++ b/sfaira/unit_tests/utils.py @@ -0,0 +1,53 @@ +import anndata +import numpy as np +import os + +from sfaira.data import Universe + + +def simulate_anndata(genes, n_obs, targets=None, assays=None) -> anndata.AnnData: +    """ +    Simulate basic data example. + +    :return: AnnData instance. +    """ +    data = anndata.AnnData( +        np.random.randint(low=0, high=100, size=(n_obs, len(genes))).astype(np.float32) +    ) +    if assays is not None: +        data.obs["assay_sc"] = [ +            assays[np.random.randint(0, len(assays))] +            for i in range(n_obs) +        ] +    if targets is not None: +        data.obs["cell_ontology_class"] = [ +            targets[np.random.randint(0, len(targets))] +            for i in range(n_obs) +        ] +    data.var["ensembl"] = genes +    return data + + +def cached_store_writing(dir_data, dir_meta, assembly) -> os.PathLike: +    """ +    Writes a store if it does not already exist. + +    :return: Path to store. +    """ +    store_path = os.path.join(dir_data, "store") +    ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) +    ds.subset(key="organism", values=["mouse"]) +    ds.subset(key="organ", values=["lung"]) +    # Only load files that are not already in cache.
+ anticipated_files = np.unique([ + v.doi for k, v in ds.datasets.items() + if not os.path.exists(os.path.join(store_path, v.doi_cleaned_id + ".h5ad")) + ]).tolist() + ds.subset(key="doi", values=anticipated_files) + ds.load(allow_caching=True) + ds.streamline_features(remove_gene_version=True, match_to_reference={"mouse": assembly}, + subset_genes_to_type="protein_coding") + ds.streamline_metadata(schema="sfaira", uns_to_obs=True, clean_obs=True, clean_var=True, clean_uns=True, + clean_obs_names=True) + ds.write_distributed_store(dir_cache=store_path, store="h5ad", dense=False) + return store_path
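Editor's note on the DistributedStore generator exercised in test_generator_shapes above: a minimal consumption sketch is given below. It only uses signatures visible in this diff (DistributedStore, GenomeContainer.subset, the genome_container setter and generator()); store_path is a placeholder for a directory of pre-processed .h5ad files, for example one produced by cached_store_writing.

# Minimal sketch, assuming a store has already been written to `store_path` (placeholder path).
from sfaira.data import DistributedStore
from sfaira.versions.genomes import GenomeContainer

store_path = "/path/to/store"  # e.g. the directory returned by cached_store_writing
store = DistributedStore(cache_path=store_path)
gc = GenomeContainer(assembly="Mus_musculus.GRCm38.102")
gc.subset(biotype="protein_coding")
store.genome_container = gc  # validates that all gc.ensembl ids are present in the store

g = store.generator(idx=None, batch_size=10, obs_keys=["cell_ontology_class"], return_dense=True)
n_cells = 0
for x, obs in g():
    # x: (batch, n_vars) count matrix; obs: pandas.DataFrame with the requested columns
    n_cells += x.shape[0]
print(f"iterated over {n_cells} cells")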
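The batch-size assertion in test_generator_shapes relies on the invariant that the generator walks each data set separately, so only the final batch of each data set may be short. A toy sketch of that accounting, using a hypothetical helper that is not part of sfaira:

def expected_batch_sizes(n_obs_per_dataset, batch_size):
    # Per data set: full batches first, then at most one short remainder batch.
    sizes = []
    for n in n_obs_per_dataset:
        n_full, remainder = divmod(n, batch_size)
        sizes.extend([batch_size] * n_full)
        if remainder:
            sizes.append(remainder)
    return sizes

# Two data sets with 25 and 13 cells at batch_size=10 give at most one short batch each:
assert expected_batch_sizes([25, 13], 10) == [10, 10, 5, 10, 3]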
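test_split_index_sets checks two split modes: a float test_split (test cells drawn from every data set) and an attribute split such as {"assay_sc": "10x sequencing"} (whole matching data sets go to the test partition). The sketch below only mirrors the properties the test asserts; split_by_attribute is a hypothetical helper, not the EstimatorKeras implementation.

import numpy as np
import pandas as pd

def split_by_attribute(obs, attr_key, attr_value, val_fraction=0.1, seed=0):
    # Test indices are defined by the .obs attribute match; train/val are sampled from the rest.
    rng = np.random.default_rng(seed)
    idx_all = np.arange(obs.shape[0])
    idx_test = idx_all[obs[attr_key].values == attr_value]
    idx_rest = rng.permutation(np.setdiff1d(idx_all, idx_test))
    n_val = int(len(idx_rest) * val_fraction)
    return idx_rest[n_val:], idx_rest[:n_val], idx_test  # train, val, test

obs = pd.DataFrame({"assay_sc": ["10x sequencing"] * 30 + ["Smart-seq2"] * 70})
idx_train, idx_eval, idx_test = split_by_attribute(obs, "assay_sc", "10x sequencing")
assert len(idx_train) + len(idx_eval) + len(idx_test) == obs.shape[0]
assert set(idx_test) == set(range(30))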
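The new sfaira.unit_tests.utils.simulate_anndata helper replaces the simulation previously inlined in HelperEstimatorBase. Usage mirroring HelperEstimatorBase._simulate and HelperTrainerBase._simulate, with the GENES, TARGETS and ASSAYS constants from test_estimator.py:

from sfaira.unit_tests.utils import simulate_anndata

adata = simulate_anndata(
    genes=["ENSMUSG00000000003", "ENSMUSG00000000028"],
    n_obs=100,
    targets=["T cell", "stromal cell"],
    assays=["10x sequencing", "Smart-seq2"],
)
assert adata.n_obs == 100 and adata.n_vars == 2
assert {"assay_sc", "cell_ontology_class"}.issubset(set(adata.obs.columns))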