From 82e02b6bdfaf29a865cff9b4a83e3220ad669711 Mon Sep 17 00:00:00 2001 From: "David S. Fischer" Date: Tue, 15 Jun 2021 14:55:53 +0200 Subject: [PATCH 01/15] Cellxgene export (#315) * updated count rounding warning in streamlining * improved meta data streamlining * updated DOIs to distinguish preprint and journal --- sfaira/consts/adata_fields.py | 18 +- sfaira/consts/ontologies.py | 3 +- sfaira/data/base/dataset.py | 282 +++++++++++------- sfaira/data/base/dataset_group.py | 4 +- ...letoflangerhans_2017_smartseq2_enge_001.py | 3 +- .../mouse_x_2018_microwellseq_han_x.py | 2 +- ...fcolon_2019_10xsequencing_kinchen_001.yaml | 6 +- ...pithelium_2019_10xsequencing_smilie_001.py | 2 +- ...man_ileum_2019_10xsequencing_martin_001.py | 2 +- ...stategland_2018_10xsequencing_henry_001.py | 2 +- .../human_pancreas_2016_indrop_baron_001.py | 2 +- ...pancreas_2016_smartseq2_segerstolpe_001.py | 2 +- ..._pancreas_2019_10xsequencing_thompson_x.py | 2 +- ...uman_lung_2020_10xsequencing_miller_001.py | 2 +- ...an_brain_2019_dropseq_polioudakis_001.yaml | 6 +- .../human_brain_2017_droncseq_habib_001.py | 2 +- ...human_testis_2018_10xsequencing_guo_001.py | 2 +- ...liver_2018_10xsequencing_macparland_001.py | 2 +- .../human_kidney_2019_droncseq_lake_001.py | 2 +- .../human_x_2019_10xsequencing_szabo_001.py | 2 +- ...man_retina_2019_10xsequencing_menon_001.py | 2 +- .../human_placenta_2018_x_ventotormo_001.py | 2 +- .../human_liver_2019_celseq2_aizarani_001.py | 2 +- ...ver_2019_10xsequencing_ramachandran_001.py | 2 +- ...an_liver_2019_10xsequencing_popescu_001.py | 2 +- ...rain_2019_10x3v2sequencing_kanton_001.yaml | 6 +- .../human_x_2020_microwellseq_han_x.py | 2 +- .../human_lung_2020_x_travaglini_001.yaml | 6 +- ...uman_colon_2020_10xsequencing_james_001.py | 2 +- .../human_lung_2019_dropseq_braga_001.py | 2 +- .../human_x_2019_10xsequencing_braga_x.py | 2 +- .../mouse_x_2019_10xsequencing_hove_001.py | 2 +- ...uman_kidney_2020_10xsequencing_liao_001.py | 2 +- ...man_retina_2019_10xsequencing_voigt_001.py | 2 +- .../human_x_2019_10xsequencing_wang_001.py | 2 +- ...an_lung_2020_10xsequencing_lukassen_001.py | 3 +- .../human_blood_2020_10x_hao_001.yaml | 3 +- .../d10_1101_661728/mouse_x_2019_x_pisco_x.py | 3 +- ...nchyma_2020_10xsequencing_habermann_001.py | 3 +- ...n_kidney_2019_10xsequencing_stewart_001.py | 2 +- ...uman_thymus_2020_10xsequencing_park_001.py | 2 +- .../human_x_2020_scirnaseq_cao_001.yaml | 4 +- ...uman_x_2019_10xsequencing_madissoon_001.py | 2 +- ..._retina_2019_10xsequencing_lukowski_001.py | 2 +- ...lood_2019_10xsequencing_10xgenomics_001.py | 2 +- .../human_x_2018_10xsequencing_regev_001.py | 2 +- .../data/utils_scripts/streamline_selected.py | 14 +- sfaira/unit_tests/utils.py | 2 +- 48 files changed, 262 insertions(+), 168 deletions(-) diff --git a/sfaira/consts/adata_fields.py b/sfaira/consts/adata_fields.py index 041fe83e8..9a99fc047 100644 --- a/sfaira/consts/adata_fields.py +++ b/sfaira/consts/adata_fields.py @@ -17,7 +17,8 @@ class AdataIds: cellontology_id: str development_stage: str disease: str - doi: str + doi_journal: str + doi_preprint: str download_url_data: str download_url_meta: str dataset: str @@ -87,7 +88,8 @@ def __init__(self): self.cellontology_id = "cell_ontology_id" self.default_embedding = "default_embedding" self.disease = "disease" - self.doi = "doi" + self.doi_journal = "doi_journal" + self.doi_preprint = "doi_preprint" self.dataset = "dataset" self.dataset_group = "dataset_group" self.download_url_data = "download_url_data" @@ -123,6 +125,7 @@ def 
__init__(self): self.unknown_celltype_identifier = "UNKNOWN" self.not_a_cell_celltype_identifier = "NOT_A_CELL" self.unknown_metadata_identifier = "unknown" + self.unknown_metadata_ontology_id_identifier = "unknown" self.obs_keys = [ "assay_sc", @@ -152,7 +155,8 @@ def __init__(self): "annotated", "author", "default_embedding", - "doi", + "doi_journal", + "doi_preprint", "download_url_data", "download_url_meta", "id", @@ -182,9 +186,11 @@ def __init__(self): self.cellontology_class = "cell_type" self.cellontology_id = "cell_type_ontology_term_id" self.default_embedding = "default_embedding" - self.doi = "preprint_doi" + self.doi_journal = "publication_doi" + self.doi_preprint = "preprint_doi" self.disease = "disease" self.gene_id_symbols = "gene_symbol" + self.gene_id_ensembl = "ensembl" self.gene_id_index = self.gene_id_symbols self.id = "id" self.ncells = "ncells" @@ -228,11 +234,13 @@ def __init__(self): "tech_sample", ] self.var_keys = [ + "gene_id_ensembl", "gene_id_symbols", ] self.uns_keys = [ + "doi_journal", + "doi_preprint", "default_embedding", - "id", "title", ] # These attributes related to obs and uns keys above are also in the data set attributes that can be diff --git a/sfaira/consts/ontologies.py b/sfaira/consts/ontologies.py index 8f18526ce..76d9f77ff 100644 --- a/sfaira/consts/ontologies.py +++ b/sfaira/consts/ontologies.py @@ -24,7 +24,8 @@ def __init__(self): "mouse": OntologyMmusdv(), } self.disease = OntologyMondo() - self.doi = None + self.doi_journal = None + self.doi_preprint = None self.ethnicity = { "human": None, # TODO OntologyHancestro "mouse": None, diff --git a/sfaira/data/base/dataset.py b/sfaira/data/base/dataset.py index ef1236b8a..9801e8470 100644 --- a/sfaira/data/base/dataset.py +++ b/sfaira/data/base/dataset.py @@ -71,6 +71,10 @@ def clean_string(s): return s +def get_directory_formatted_doi(x: str) -> str: + return "d" + "_".join("_".join("_".join(x.split("/")).split(".")).split("-")) + + class DatasetBase(abc.ABC): adata: Union[None, anndata.AnnData] class_maps: dict @@ -91,7 +95,8 @@ class DatasetBase(abc.ABC): _default_embedding: Union[None, str] _development_stage: Union[None, str] _disease: Union[None, str] - _doi: Union[None, str] + _doi_journal: Union[None, str] + _doi_preprint: Union[None, str] _download_url_data: Union[Tuple[List[None]], Tuple[List[str]], None] _download_url_meta: Union[Tuple[List[None]], Tuple[List[str]], None] _ethnicity: Union[None, str] @@ -198,7 +203,8 @@ def __init__( self._default_embedding = None self._development_stage = None self._disease = None - self._doi = None + self._doi_journal = None + self._doi_preprint = None self._download_url_data = None self._download_url_meta = None self._ethnicity = None @@ -269,7 +275,10 @@ def __init__( assert self.sample_fn in v.keys(), f"did not find key {self.sample_fn} in yamls keys for {k}" setattr(self, k, v[self.sample_fn]) else: # v is a meta-data item - setattr(self, k, v) + try: + setattr(self, k, v) + except AttributeError as e: + raise ValueError(f"An error occured when setting {k} as {v}: {e}") # ID can be set now already because YAML was used as input instead of child class constructor. self.set_dataset_id(idx=yaml_vals["meta"]["dataset_index"]) @@ -391,6 +400,7 @@ def _download_synapse(self, synapse_entity, fn, **kwargs): @property def cache_fn(self): if self.directory_formatted_doi is None or self._directory_formatted_id is None: + # TODO is this case necessary? warnings.warn("Caching enabled, but Dataset.id or Dataset.doi not set. 
Disabling caching for now.") return None else: @@ -562,13 +572,15 @@ def streamline_features( adata.var columns that are not defined as gene_id_ensembl_var_key or gene_id_symbol_var_key in the dataloader. :param match_to_reference: Which annotation to map the feature space to. Can be: - - str: Provide the name of the annotation in the format Organism.Assembly.Release - - dict: Mapping of organism to name of the annotation (see str format). Chooses annotation for each data set - based on organism annotation. - :param remove_gene_version: Whether to remove the version number after the colon sometimes found in ensembl gene ids. - :param subset_genes_to_type: Type(s) to subset to. Can be a single type or a list of types or None. Types can be: - - None: All genes in assembly. - - "protein_coding": All protein coding genes in assembly. + - str: Provide the name of the annotation in the format Organism.Assembly.Release + - dict: Mapping of organism to name of the annotation (see str format). Chooses annotation for each + data set based on organism annotation. + :param remove_gene_version: Whether to remove the version number after the colon sometimes found in ensembl + gene ids. + :param subset_genes_to_type: Type(s) to subset to. Can be a single type or a list of types or None. + Types can be: + - None: All genes in assembly. + - "protein_coding": All protein coding genes in assembly. """ self.__assert_loaded() @@ -741,34 +753,37 @@ def streamline_metadata( # set var index var_new.index = var_new[adata_target_ids.gene_id_index].tolist() - per_cell_labels = ["cell_types_original", "cellontology_class", "cellontology_id"] - experiment_batch_labels = ["bio_sample", "individual", "tech_sample"] - - # Prepare .obs column name dict (process keys below with other .uns keys if they're set dataset-wide) - obs_cols = {} - for k in adata_target_ids.obs_keys: - # Skip any per-cell labels for now and process them in the next code block - if k in per_cell_labels: - continue - else: - if hasattr(self, f"{k}_obs_key") and getattr(self, f"{k}_obs_key") is not None: - obs_cols[k] = (getattr(self, f"{k}_obs_key"), getattr(adata_target_ids, k)) - else: - adata_target_ids.uns_keys.append(k) - # Prepare new .uns dict: uns_new = {} for k in adata_target_ids.uns_keys: val = getattr(self, k) - while hasattr(val, '__len__') and not isinstance(val, str) and len(val) == 1: # unpack nested lists/tuples + if val is None and hasattr(self, f"{k}_obs_key"): + val = np.sort(self.adata.obs[getattr(self, f"{k}_obs_key")].values.tolist()) + while hasattr(val, '__len__') and not isinstance(val, str) and len(val) == 1: # Unpack nested lists/tuples. 
val = val[0] uns_new[getattr(adata_target_ids, k)] = val # Prepare new .obs dataframe + experiment_batch_labels = ["bio_sample", "individual", "tech_sample"] + per_cell_labels = ["cell_types_original", "cellontology_class", "cellontology_id"] obs_new = pd.DataFrame(index=self.adata.obs.index) - for k, (old_col, new_col) in obs_cols.items(): + # Handle non-cell type labels: + for k in [x for x in adata_target_ids.obs_keys if x not in per_cell_labels]: + if hasattr(self, f"{k}_obs_key") and getattr(self, f"{k}_obs_key") is not None: + old_col = getattr(self, f"{k}_obs_key") + val = self.adata.obs[old_col].values.tolist() + else: + old_col = None + val = getattr(self, k) + if val is None: + val = self._adata_ids.unknown_metadata_identifier + # Unpack nested lists/tuples: + while hasattr(val, '__len__') and not isinstance(val, str) and len(val) == 1: + val = val[0] + val = [val] * self.adata.n_obs + new_col = getattr(adata_target_ids, k) # Handle batch-annotation columns which can be provided as a combination of columns separated by an asterisk - if k in experiment_batch_labels and "*" in old_col: + if old_col is not None and k in experiment_batch_labels and "*" in old_col: batch_cols = [] for batch_col in old_col.split("*"): if batch_col in self.adata.obs_keys(): @@ -779,46 +794,41 @@ def streamline_metadata( # in .obs set. print(f"WARNING: attribute {new_col} of data set {self.id} was not found in column {batch_col}") # Build a combination label out of all columns used to describe this group. - obs_new[new_col] = [ + val = [ "_".join([str(xxx) for xxx in xx]) for xx in zip(*[self.adata.obs[batch_col].values.tolist() for batch_col in batch_cols]) ] - setattr(self, f"{k}_obs_key", new_col) # update _obs_column attribute of this class to match the new column # All other .obs fields are interpreted below as provided else: - # Search for direct match of the sought-after column name or for attribute specific obs key. - if old_col in self.adata.obs_keys(): - # Include flag in .uns that this attribute is in .obs: - uns_new[new_col] = UNS_STRING_META_IN_OBS - # Remove potential pd.Categorical formatting: - ontology = getattr(self.ontology_container_sfaira, k) if hasattr(self.ontology_container_sfaira, k) else None - if k == "development_stage": - ontology = ontology[self.organism] - if k == "ethnicity": - ontology = ontology[self.organism] - self._value_protection(attr=new_col, allowed=ontology, attempted=np.unique(self.adata.obs[old_col].values).tolist()) - obs_new[new_col] = self.adata.obs[old_col].values.tolist() - del self.adata.obs[old_col] - setattr(self, f"{k}_obs_key", new_col) # update _obs_column attribute of this class to match the new column - else: - # This should not occur in single data set loaders (see warning below) but can occur in - # streamlined data loaders if not all instances of the streamlined data sets have all columns - # in .obs set. - uns_new[new_col] = None - print(f"WARNING: attribute {new_col} of data set {self.id} was not found in column {old_col}") - - # Set cell-wise attributes (.obs): (None so far other than celltypes.) 
+ # Check values for validity: + ontology = getattr(self.ontology_container_sfaira, k) \ + if hasattr(self.ontology_container_sfaira, k) else None + if k == "development_stage": + ontology = ontology[self.organism] + if k == "ethnicity": + ontology = ontology[self.organism] + self._value_protection(attr=new_col, allowed=ontology, attempted=[ + x for x in np.unique(val) + if x not in [ + self._adata_ids.unknown_metadata_identifier, + self._adata_ids.unknown_metadata_ontology_id_identifier, + ] + ]) + obs_new[new_col] = val + setattr(self, f"{k}_obs_key", new_col) # Set cell types: - # Map cell type names from raw IDs to ontology maintained ones: + # Build auxilliary table with cell type information: if self.cell_types_original_obs_key is not None: - obs_cl = self.project_celltypes_to_ontology(copy=True, adata_fields=adata_target_ids) + obs_cl = self.project_celltypes_to_ontology(copy=True, adata_fields=self._adata_ids) else: obs_cl = pd.DataFrame({ - adata_target_ids.cellontology_class: [adata_target_ids.unknown_metadata_identifier] * self.adata.n_obs, - adata_target_ids.cellontology_id: [adata_target_ids.unknown_metadata_identifier] * self.adata.n_obs, - adata_target_ids.cell_types_original: [adata_target_ids.unknown_metadata_identifier] * self.adata.n_obs, + self._adata_ids.cellontology_class: [self._adata_ids.unknown_metadata_identifier] * self.adata.n_obs, + self._adata_ids.cellontology_id: [self._adata_ids.unknown_metadata_identifier] * self.adata.n_obs, + self._adata_ids.cell_types_original: [self._adata_ids.unknown_metadata_identifier] * self.adata.n_obs, }, index=self.adata.obs.index) - obs_new = pd.concat([obs_new, obs_cl], axis=1) + for k in [x for x in per_cell_labels if x in adata_target_ids.obs_keys]: + obs_new[getattr(adata_target_ids, k)] = obs_cl[getattr(self._adata_ids, k)] + del obs_cl # Add new annotation to adata and delete old fields if requested if clean_var: @@ -862,6 +872,31 @@ def streamline_metadata( else: self.adata.uns = {**self.adata.uns, **uns_new} + # Make sure that correct unknown_metadata_identifier is used in .uns, .obs and .var metadata + unknown_old = self._adata_ids.unknown_metadata_identifier + unknown_new = adata_target_ids.unknown_metadata_identifier + self.adata.obs = self.adata.obs.replace({None: unknown_new}) + self.adata.obs = self.adata.obs.replace({unknown_old: unknown_new}) + self.adata.var = self.adata.var.replace({None: unknown_new}) + self.adata.var = self.adata.var.replace({unknown_old: unknown_new}) + for k in self.adata.uns_keys(): + if self.adata.uns[k] is None or self.adata.uns[k] == unknown_old: + self.adata.uns[k] = unknown_new + + # Move all uns annotation to obs columns if requested + if uns_to_obs: + for k, v in self.adata.uns.items(): + if k not in self.adata.obs_keys(): + if v is None: + v = self._adata_ids.unknown_metadata_identifier + # Unpack nested lists/tuples: + while hasattr(v, '__len__') and not isinstance(v, str) and len(v) == 1: + v = v[0] + self.adata.obs[k] = [v for _ in range(self.adata.n_obs)] + # Retain only target uns keys in .uns. 
+ self.adata.uns = dict([(k, v) for k, v in self.adata.uns.items() + if k in [getattr(adata_target_ids, kk) for kk in ["id"]]]) + # Add additional hard-coded description changes for cellxgene schema: if schema == "cellxgene": self.adata.uns["layer_descriptions"] = {"X": "raw"} @@ -869,6 +904,11 @@ def streamline_metadata( "corpora_encoding_version": "0.1.0", "corpora_schema_version": "1.1.0", } + self.adata.uns["contributors"] = { + "name": "sfaira", + "email": "https://github.com/theislab/sfaira/issues", + "institution": "sfaira", + } # TODO port this into organism ontology handling. if self.organism == "mouse": self.adata.uns["organism"] = "Mus musculus" @@ -879,8 +919,11 @@ def streamline_metadata( else: raise ValueError(f"organism {self.organism} currently not supported by cellxgene schema") # Add ontology IDs where necessary (note that human readable terms are also kept): - for k in ["organ", "assay_sc", "disease", "ethnicity", "development_stage"]: - if getattr(adata_target_ids, k) in self.adata.obs.columns: + ontology_cols = ["organ", "assay_sc", "disease", "ethnicity", "development_stage"] + non_ontology_cols = ["sex"] + for k in ontology_cols: + # TODO enable ethinicity once the distinction between ontology for human and None for mouse works. + if getattr(adata_target_ids, k) in self.adata.obs.columns and k != "ethnicity": ontology = getattr(self.ontology_container_sfaira, k) # Disambiguate organism-dependent ontologies: if isinstance(ontology, dict): @@ -889,13 +932,31 @@ def streamline_metadata( ontology=ontology, key_in=getattr(adata_target_ids, k), key_out=getattr(adata_target_ids, k) + "_ontology_term_id", - map_exceptions=[], + map_exceptions=[adata_target_ids.unknown_metadata_identifier], map_exceptions_value=adata_target_ids.unknown_metadata_ontology_id_identifier, ) else: self.adata.obs[getattr(adata_target_ids, k)] = adata_target_ids.unknown_metadata_identifier self.adata.obs[getattr(adata_target_ids, k) + "_ontology_term_id"] = \ adata_target_ids.unknown_metadata_ontology_id_identifier + # Correct unknown cell type entries: + self.adata.obs[getattr(adata_target_ids, "cellontology_class")] = [ + x if x not in [self._adata_ids.unknown_celltype_identifier, + self._adata_ids.not_a_cell_celltype_identifier] + else "native cell" + for x in self.adata.obs[getattr(adata_target_ids, "cellontology_class")]] + self.adata.obs[getattr(adata_target_ids, "cellontology_id")] = [ + x if x not in [self._adata_ids.unknown_celltype_identifier, + self._adata_ids.not_a_cell_celltype_identifier] + else "CL:0000003" + for x in self.adata.obs[getattr(adata_target_ids, "cellontology_id")]] + # Reorder data frame to put ontology columns first: + cellxgene_cols = [getattr(adata_target_ids, x) for x in ontology_cols] + \ + [getattr(adata_target_ids, x) for x in non_ontology_cols] + \ + [getattr(adata_target_ids, x) + "_ontology_term_id" for x in ontology_cols] + self.adata.obs = self.adata.obs[ + cellxgene_cols + [x for x in self.adata.obs.columns if x not in cellxgene_cols] + ] # Adapt var columns naming. if self.organism == "human": gene_id_new = "hgnc_gene_symbol" @@ -911,27 +972,13 @@ def streamline_metadata( # Check if .X is counts: The conversion are based on the assumption that .X is csr. assert isinstance(self.adata.X, scipy.sparse.csr_matrix), type(self.adata.X) count_values = np.unique(np.asarray(self.adata.X.todense())) - is_counts = np.all(count_values % 1. == 0.) - if not is_counts: - print(f"WARNING: not all count entries were counts {is_counts}. 
rounding.") + if not np.all(count_values % 1. == 0.): + print(f"WARNING: not all count entries were counts, " + f"the maximum deviation from integer is " + f"{np.max([x % 1. if x % 1. < 0.5 else 1. - x % 1. for x in count_values])}. " + f"The count matrix is rounded.") self.adata.X.data = np.rint(self.adata.X.data) - # Make sure that correct unknown_metadata_identifier is used in .uns, .obs and .var metadata - self.adata.obs = self.adata.obs.replace({None: adata_target_ids.unknown_metadata_identifier}) - self.adata.var = self.adata.var.replace({None: adata_target_ids.unknown_metadata_identifier}) - for k in self.adata.uns_keys(): - if self.adata.uns[k] is None: - self.adata.uns[k] = adata_target_ids.unknown_metadata_identifier - - # Move all uns annotation to obs columns if requested - if uns_to_obs: - for k, v in self.adata.uns.items(): - if k not in self.adata.obs_keys(): - self.adata.obs[k] = [v for i in range(self.adata.n_obs)] - # Retain only target uns keys in .uns. - self.adata.uns = dict([(k, v) for k, v in self.adata.uns.items() - if k in [getattr(adata_target_ids, kk) for kk in ["id"]]]) - self._adata_ids = adata_target_ids # set new adata fields to class after conversion self.streamlined_meta = True @@ -1267,7 +1314,7 @@ def citation(self): :return: """ - return [self.author, self.year, self.doi] + return [self.author, self.year, self.doi_journal] # Meta data handling code: Reading, writing and selected properties. Properties are either set in constructor # (and saved in self._somename) or accessed in self.meta. @@ -1515,8 +1562,17 @@ def data_dir(self): return None else: sfaira_path = os.path.join(self.data_dir_base, self.directory_formatted_doi) + # Allow checking in secondary path, named after second DOI associated with study. + # This allows association of raw data already downloaded even after DOI is updated. + if self.doi_preprint is not None: + sfaira_path_secondary = os.path.join(self.data_dir_base, + get_directory_formatted_doi(x=self.doi_preprint)) + else: + sfaira_path_secondary = None if os.path.exists(sfaira_path): return sfaira_path + elif self.doi_preprint is not None and os.path.exists(sfaira_path_secondary): + return sfaira_path_secondary else: return self.data_dir_base @@ -1576,31 +1632,53 @@ def disease(self, x: str): self._disease = x @property - def doi(self) -> Union[str, List[str]]: - if self._doi is not None: - return self._doi - else: - if self.meta is None: - self.load_meta(fn=None) - if self.meta is None or self._adata_ids.doi not in self.meta.columns: - raise ValueError("doi must be set but was neither set in constructor nor in meta data") - return self.meta[self._adata_ids.doi] + def doi_journal(self) -> str: + """ + The prepring publication (secondary) DOI associated with the study. + See also `.doi_journal`. + """ + return self._doi_journal + + @doi_journal.setter + def doi_journal(self, x: str): + self._doi_journal = x + + @property + def doi_preprint(self) -> str: + """ + The journal publication (main) DOI associated with the study. + See also `.doi_preprint`. + """ + return self._doi_preprint - @doi.setter - def doi(self, x: Union[str, List[str]]): - self._doi = x + @doi_preprint.setter + def doi_preprint(self, x: str): + self._doi_preprint = x + + @property + def doi(self) -> List[str]: + """ + All publication DOI associated with the study which are the journal publication and the preprint. + See also `.doi_preprint`, `.doi_journal`. 
+ """ + dois = [] + if self.doi_journal is not None: + dois.append(self.doi_journal) + if self.doi_preprint is not None: + dois.append(self.doi_preprint) + return dois @property def doi_main(self) -> str: """ - Yields the main DOI associated with the study, defined as the DOI that comes first in alphabetical order. + The main DOI associated with the study which is the journal publication if available, otherwise the preprint. + See also `.doi_preprint`, `.doi_journal`. """ - return self.doi if isinstance(self.doi, str) else np.sort(self.doi)[0] + return self.doi_preprint if self.doi_journal is None else self.doi_journal @property def directory_formatted_doi(self) -> str: - # Chose first doi in list. - return "d" + "_".join("_".join("_".join(self.doi_main.split("/")).split(".")).split("-")) + return get_directory_formatted_doi(x=self.doi_main) @property def download_url_data(self) -> Union[Tuple[List[str]], Tuple[List[None]]]: @@ -1610,12 +1688,7 @@ def download_url_data(self) -> Union[Tuple[List[str]], Tuple[List[None]]]: Save as tuple with single element, which is a list of all download websites relevant to dataset. :return: """ - if self._download_url_data is not None: - x = self._download_url_data - else: - if self.meta is None: - self.load_meta(fn=None) - x = self.meta[self._adata_ids.download_url_data] + x = self._download_url_data if isinstance(x, str) or x is None: x = [x] if isinstance(x, list): @@ -2135,7 +2208,8 @@ def __crossref_query(self, k): if k == "author": pass return x - except ValueError: + except ValueError as e: + print(f"ValueError: {e}") return None except ConnectionError as e: print(f"ConnectionError: {e}") diff --git a/sfaira/data/base/dataset_group.py b/sfaira/data/base/dataset_group.py index 98f2e2a70..1c1675364 100644 --- a/sfaira/data/base/dataset_group.py +++ b/sfaira/data/base/dataset_group.py @@ -561,11 +561,11 @@ def doi(self) -> List[str]: """ dois = [] for _, v in self.datasets.items(): - vdoi = v.doi + vdoi = v.doi_journal if isinstance(vdoi, str): vdoi = [vdoi] dois.extend(vdoi) - return np.sort(np.unique(vdoi)).tolist() + return np.sort(np.unique(dois)).tolist() @property def supplier(self) -> List[str]: diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2017_09_004/human_isletoflangerhans_2017_smartseq2_enge_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2017_09_004/human_isletoflangerhans_2017_smartseq2_enge_001.py index e89f75cac..47e162a49 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2017_09_004/human_isletoflangerhans_2017_smartseq2_enge_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2017_09_004/human_isletoflangerhans_2017_smartseq2_enge_001.py @@ -18,7 +18,8 @@ def __init__(self, **kwargs): self.author = "Enge" self.disease = "healthy" - self.doi = "10.1016/j.cell.2017.09.004" + self.doi_journal = "10.1016/j.cell.2017.09.004" + self.doi_preprint = "10.1101/108043" self.normalization = "raw" self.assay_sc = "Smart-seq2" self.organ = "islet of Langerhans" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_02_001/mouse_x_2018_microwellseq_han_x.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_02_001/mouse_x_2018_microwellseq_han_x.py index 24e8c3103..bc37cd92c 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_02_001/mouse_x_2018_microwellseq_han_x.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_02_001/mouse_x_2018_microwellseq_han_x.py @@ -301,7 +301,7 @@ def __init__(self, **kwargs): self.author = "Han" self.dev_stage = 
sample_dev_stage_dict[self.sample_fn] self.disease = "healthy" - self.doi = "10.1016/j.cell.2018.02.001" + self.doi_journal = "10.1016/j.cell.2018.02.001" self.normalization = "raw" self.organism = "mouse" self.assay_sc = "microwell-seq" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml index 3d3f45314..dda8e81a9 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml @@ -7,9 +7,9 @@ dataset_wise: author: - "Kinchen" default_embedding: - doi: - - "10.1016/j.cell.2018.08.067" - download_url_data: + doi_journal: "10.1016/j.cell.2018.08.067" + doi_preprint: + download_url_data: HC: "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE114374&format=file&file=GSE114374%5FHuman%5FHC%5Fexpression%5Fmatrix%2Etxt%2Egz" UC: "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE114374&format=file&file=GSE114374%5FHuman%5FUC%5Fexpression%5Fmatrix%2Etxt%2Egz" download_url_meta: diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py index 8fcc18add..94eb3decb 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py @@ -16,7 +16,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Smilie" self.disease = "healthy" - self.doi = "10.1016/j.cell.2019.06.029" + self.doi_journal = "10.1016/j.cell.2019.06.029" self.normalization = "raw" self.organ = "colonic epithelium" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py index 84bb3cb0f..a0b2f0bec 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py @@ -16,7 +16,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Martin" self.disease = "healthy" - self.doi = "10.1016/j.cell.2019.08.008" + self.doi_journal = "10.1016/j.cell.2019.08.008" self.normalization = "raw" self.organ = "ileum" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py index a5ed8b6a4..dce8f3174 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py @@ -21,7 +21,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Henry" 
self.disease = "healthy" - self.doi = "10.1016/j.celrep.2018.11.086" + self.doi_journal = "10.1016/j.celrep.2018.11.086" self.normalization = "raw" self.sample_source = "primary_tissue" self.state_exact = "healthy" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py index d7e33841b..a8fe3edef 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py @@ -19,7 +19,7 @@ def __init__(self, **kwargs): self.assay_sc = "inDrop" self.author = "Baron" self.disease = "healthy" - self.doi = "10.1016/j.cels.2016.08.011" + self.doi_journal = "10.1016/j.cels.2016.08.011" self.normalization = "raw" self.organ = "pancreas" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py index 6abb1637c..d7dcf8b54 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py @@ -17,7 +17,7 @@ def __init__(self, **kwargs): self.download_url_meta = "https://www.ebi.ac.uk/arrayexpress/files/E-MTAB-5061/E-MTAB-5061.sdrf.txt" self.author = "Segerstolpe" - self.doi = "10.1016/j.cmet.2016.08.020" + self.doi_journal = "10.1016/j.cmet.2016.08.020" self.normalization = "raw" self.organ = "pancreas" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2019_01_021/mouse_pancreas_2019_10xsequencing_thompson_x.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2019_01_021/mouse_pancreas_2019_10xsequencing_thompson_x.py index d06889f2e..a6821a2af 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2019_01_021/mouse_pancreas_2019_10xsequencing_thompson_x.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2019_01_021/mouse_pancreas_2019_10xsequencing_thompson_x.py @@ -29,7 +29,7 @@ def __init__(self, **kwargs): self.download_url_meta = f"private,{self.sample_fn}_annotation.csv" self.author = "Thompson" - self.doi = "10.1016/j.cmet.2019.01.021" + self.doi_journal = "10.1016/j.cmet.2019.01.021" self.normalization = "raw" self.organ = "pancreas" self.organism = "mouse" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py index ea7d78a77..f8634dfbb 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py @@ -16,7 +16,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Miller" self.disease = "healthy" - self.doi = "10.1016/j.devcel.2020.01.033" + self.doi_journal = "10.1016/j.devcel.2020.01.033" self.normalization = "raw" self.organ = "lung" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_neuron_2019_06_011/human_brain_2019_dropseq_polioudakis_001.yaml 
b/sfaira/data/dataloaders/loaders/d10_1016_j_neuron_2019_06_011/human_brain_2019_dropseq_polioudakis_001.yaml index 667f9b4d2..c099579ba 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_neuron_2019_06_011/human_brain_2019_dropseq_polioudakis_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_neuron_2019_06_011/human_brain_2019_dropseq_polioudakis_001.yaml @@ -4,9 +4,9 @@ dataset_structure: dataset_wise: author: - "Polioudakis" - doi: - - "10.1016/j.neuron.2019.06.011" - download_url_data: + doi_journal: "10.1016/j.neuron.2019.06.011" + doi_preprint: + download_url_data: - "manual,sc_dev_cortex_geschwind.zip,http://solo.bmap.ucla.edu/shiny/webapp" download_url_meta: normalization: "raw" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_nmeth_4407/human_brain_2017_droncseq_habib_001.py b/sfaira/data/dataloaders/loaders/d10_1038_nmeth_4407/human_brain_2017_droncseq_habib_001.py index 27ef2f609..614c855a5 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_nmeth_4407/human_brain_2017_droncseq_habib_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_nmeth_4407/human_brain_2017_droncseq_habib_001.py @@ -16,7 +16,7 @@ def __init__(self, **kwargs): self.assay_sc = "DroNc-seq" self.author = "Habib" self.disease = "healthy" - self.doi = "10.1038/nmeth.4407" + self.doi_journal = "10.1038/nmeth.4407" self.normalization = "raw" self.organ = "brain" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py index 8fb1f6b59..737dd21c8 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py @@ -16,7 +16,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Guo" self.disease = "healthy" - self.doi = "10.1038/s41422-018-0099-2" + self.doi_journal = "10.1038/s41422-018-0099-2" self.normalization = "raw" self.organ = "testis" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py index b581a4338..057da63fd 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py @@ -15,7 +15,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "MacParland" self.disease = "healthy" - self.doi = "10.1038/s41467-018-06318-7" + self.doi_journal = "10.1038/s41467-018-06318-7" self.normalization = "raw" self.organ = "caudate lobe of liver" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_10861_2/human_kidney_2019_droncseq_lake_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_10861_2/human_kidney_2019_droncseq_lake_001.py index 5bb21b453..5cd22c4da 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_10861_2/human_kidney_2019_droncseq_lake_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_10861_2/human_kidney_2019_droncseq_lake_001.py @@ -17,7 +17,7 @@ def __init__(self, **kwargs): self.assay_sc = 
"DroNc-seq" self.author = "Lake" self.disease = "healthy" - self.doi = "10.1038/s41467-019-10861-2" + self.doi_journal = "10.1038/s41467-019-10861-2" self.normalization = "raw" self.organ = "kidney" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py index 02e672cc9..a05993e6d 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py @@ -60,7 +60,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Szabo" - self.doi = "10.1038/s41467-019-12464-3" + self.doi_journal = "10.1038/s41467-019-12464-3" self.individual = SAMPLE_DICT[self.sample_fn][1] self.normalization = "raw" self.organ = SAMPLE_DICT[self.sample_fn][0] diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py index d1891ed98..86e8bdcf8 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py @@ -14,7 +14,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Menon" self.disease = "healthy" - self.doi = "10.1038/s41467-019-12780-8" + self.doi_journal = "10.1038/s41467-019-12780-8" self.normalization = "raw" self.organ = "retina" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py index 4afa22223..49a3774f2 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py @@ -22,7 +22,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" if self.sample_fn == "E-MTAB-6678.processed" else "Smart-seq2" self.author = "Ventotormo" self.disease = "healthy" - self.doi = "10.1038/s41586-018-0698-6" + self.doi_journal = "10.1038/s41586-018-0698-6" self.normalization = "raw" self.organ = "placenta" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1373_2/human_liver_2019_celseq2_aizarani_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1373_2/human_liver_2019_celseq2_aizarani_001.py index ef2597f3b..51ea692a6 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1373_2/human_liver_2019_celseq2_aizarani_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1373_2/human_liver_2019_celseq2_aizarani_001.py @@ -15,7 +15,7 @@ def __init__(self, **kwargs): self.assay_sc = "CEL-seq2" self.author = "Aizarani" self.disease = "healthy" - self.doi = "10.1038/s41586-019-1373-2" + self.doi_journal = "10.1038/s41586-019-1373-2" self.normalization = "raw" self.sample_source = "primary_tissue" self.organ = "liver" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py 
b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py index 1c12db6b3..de71d5beb 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py @@ -16,7 +16,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Ramachandran" - self.doi = "10.1038/s41586-019-1631-3" + self.doi_journal = "10.1038/s41586-019-1631-3" self.normalization = "raw" self.organ = "liver" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py index 9c1e3efd0..b5c0f4d85 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py @@ -14,7 +14,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Popescu" self.disease = "healthy" - self.doi = "10.1038/s41586-019-1652-y" + self.doi_journal = "10.1038/s41586-019-1652-y" self.normalization = "raw" self.organ = "liver" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1654_9/human_brain_2019_10x3v2sequencing_kanton_001.yaml b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1654_9/human_brain_2019_10x3v2sequencing_kanton_001.yaml index ced5c2cd8..9b71b382d 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1654_9/human_brain_2019_10x3v2sequencing_kanton_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1654_9/human_brain_2019_10x3v2sequencing_kanton_001.yaml @@ -4,9 +4,9 @@ dataset_structure: dataset_wise: author: - "Kanton" - doi: - - "10.1038/s41586-019-1654-9" - download_url_data: + doi_journal: "10.1038/s41586-019-1654-9" + doi_preprint: + download_url_data: - "https://www.ebi.ac.uk/arrayexpress/files/E-MTAB-7552/E-MTAB-7552.processed.3.zip" download_url_meta: - "https://www.ebi.ac.uk/arrayexpress/files/E-MTAB-7552/E-MTAB-7552.processed.1.zip" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2157_4/human_x_2020_microwellseq_han_x.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2157_4/human_x_2020_microwellseq_han_x.py index d01646a31..fb6b17063 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2157_4/human_x_2020_microwellseq_han_x.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2157_4/human_x_2020_microwellseq_han_x.py @@ -19,7 +19,7 @@ def __init__(self, **kwargs): ] self.author = "Han" - self.doi = "10.1038/s41586-020-2157-4" + self.doi_journal = "10.1038/s41586-020-2157-4" self.healthy = True self.normalization = "raw" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml b/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml index b84ead614..0850a276a 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml @@ -7,9 +7,9 @@ dataset_wise: author: - "Travaglini" default_embedding: "X_tSNE" - doi: - - "10.1038/s41586-020-2922-4" 
- download_url_data: + doi_journal: "10.1038/s41586-020-2922-4" + doi_preprint: + download_url_data: droplet_normal_lung_blood_scanpy.20200205.RC4.h5ad: "syn21625095,droplet_normal_lung_blood_scanpy.20200205.RC4.h5ad" facs_normal_lung_blood_scanpy.20200205.RC4.h5ad: "syn21625142,facs_normal_lung_blood_scanpy.20200205.RC4.h5ad" download_url_meta: diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py index 5a8d212c6..34ba3268d 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py @@ -16,7 +16,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "James" self.disease = "healthy" - self.doi = "10.1038/s41590-020-0602-z" + self.doi_journal = "10.1038/s41590-020-0602-z" self.normalization = "raw" self.organ = "colon" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_lung_2019_dropseq_braga_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_lung_2019_dropseq_braga_001.py index b7ce94249..2981bfda0 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_lung_2019_dropseq_braga_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_lung_2019_dropseq_braga_001.py @@ -15,7 +15,7 @@ def __init__(self, **kwargs): self.assay_sc = "Drop-seq" self.author = "Braga" self.disease = "healthy" - self.doi = "10.1038/s41591-019-0468-5" + self.doi_journal = "10.1038/s41591-019-0468-5" self.normalization = "raw" self.organ = "lung" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py b/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py index 0a644c56b..c2c655fcd 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py @@ -20,7 +20,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Braga" self.disease = "healthy" - self.doi = "10.1038/s41591-019-0468-5" + self.doi_journal = "10.1038/s41591-019-0468-5" self.normalization = "scaled" self.organ = "bronchus" if self.sample_fn == "vieira19_Bronchi_anonymised.processed.h5ad" else "lung parenchyma" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py index f2c463bcb..7d7465ef8 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py @@ -19,7 +19,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Hove" self.disease = "healthy" - self.doi = "10.1038/s41593-019-0393-4" + self.doi_journal = "10.1038/s41593-019-0393-4" self.normalization = "raw" self.organism = "mouse" self.sample_source = "primary_tissue" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41597_019_0351_8/human_kidney_2020_10xsequencing_liao_001.py 
b/sfaira/data/dataloaders/loaders/d10_1038_s41597_019_0351_8/human_kidney_2020_10xsequencing_liao_001.py index e1f35a862..8b0ea9f8f 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41597_019_0351_8/human_kidney_2020_10xsequencing_liao_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41597_019_0351_8/human_kidney_2020_10xsequencing_liao_001.py @@ -23,7 +23,7 @@ def __init__(self, **kwargs): self.organism = "human" self.sample_source = "primary_tissue" self.year = 2020 - self.doi = "10.1038/s41597-019-0351-8" + self.doi_journal = "10.1038/s41597-019-0351-8" self.gene_id_symbols_var_key = "names" self.gene_id_ensembl_var_key = "ensembl" diff --git a/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py b/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py index 526111fe5..91b544321 100644 --- a/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py @@ -15,7 +15,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Voigt" self.disease = "healthy" - self.doi = "10.1073/pnas.1914143116" + self.doi_journal = "10.1073/pnas.1914143116" self.normalization = "norm" self.organ = "retina" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py b/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py index 5d768b35a..0d03da3c2 100644 --- a/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py @@ -24,7 +24,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Wang" self.disease = "healthy" - self.doi = "10.1084/jem.20191130" + self.doi_journal = "10.1084/jem.20191130" self.normalization = "raw" self.organ = organ self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py b/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py index c5e946f30..900a25af3 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py @@ -21,7 +21,8 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Lukassen" self.disease = "healthy" - self.doi = "10.1101/2020.03.13.991455" + self.doi_journal = "10.15252/embj.20105114" + self.doi_preprint = "10.1101/2020.03.13.991455" self.normalization = "raw" self.organ = "lung" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml b/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml index 71ca7fafe..c8860ba9c 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml @@ -5,7 +5,8 @@ dataset_wise: author: - "Hao, Yuhan" default_embedding: - doi: "10.1101/2020.10.12.335331" + doi_journal: "10.1016/j.cell.2021.04.048" + doi_preprint: 
"10.1101/2020.10.12.335331" download_url_data: "https://atlas.fredhutch.org/nygc/multimodal-pbmc/" download_url_meta: "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE164378&format=file&file=GSE164378%5Fsc%2Emeta%2Edata%5F3P%2Ecsv%2Egz" normalization: "raw" diff --git a/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py b/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py index 178ec629d..f915f5b23 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py +++ b/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py @@ -83,7 +83,8 @@ def __init__(self, **kwargs): self.author = "Pisco" self.disease = "healthy" - self.doi = "10.1101/661728" + self.doi_journal = "10.1038/s41586-020-2496-1" + self.doi_preprint = "10.1101/661728" self.normalization = "norm" self.organism = "mouse" self.organ = organ diff --git a/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py b/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py index c4a572f59..5164003b5 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py @@ -20,7 +20,8 @@ def __init__(self, **kwargs): self.download_url_meta = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE135nnn/GSE135893/suppl/GSE135893%5FIPF%5Fmetadata%2Ecsv%2Egz" self.author = "Habermann" - self.doi = "10.1101/753806" + self.doi_journal = "10.1126/sciadv.aba1972" + self.doi_preprint = "10.1101/753806" self.normalization = "raw" self.organ = "lung parenchyma" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py b/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py index 8bc35a905..8233ad6b6 100644 --- a/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py @@ -18,7 +18,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Stewart" self.disease = "healthy" - self.doi = "10.1126/science.aat5031" + self.doi_journal = "10.1126/science.aat5031" self.normalization = "norm" self.organ = "kidney" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py b/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py index 7aef677df..6d41a10ae 100644 --- a/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py @@ -15,7 +15,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Park" self.disease = "healthy" - self.doi = "10.1126/science.aay3224" + self.doi_journal = "10.1126/science.aay3224" self.normalization = "norm" self.organ = "thymus" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/d10_1126_science_aba7721/human_x_2020_scirnaseq_cao_001.yaml b/sfaira/data/dataloaders/loaders/d10_1126_science_aba7721/human_x_2020_scirnaseq_cao_001.yaml index cee3e1910..a91663fca 100644 --- 
a/sfaira/data/dataloaders/loaders/d10_1126_science_aba7721/human_x_2020_scirnaseq_cao_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1126_science_aba7721/human_x_2020_scirnaseq_cao_001.yaml @@ -5,8 +5,8 @@ dataset_wise: author: - "Cao" default_embedding: - doi: - - "10.1126/science.aba7721" + doi_journal: "10.1126/science.aba7721" + doi_preprint: download_url_data: "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE156793&format=file&file=GSE156793%5FS3%5Fgene%5Fcount%2Eloom%2Egz" download_url_meta: normalization: "raw" diff --git a/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py b/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py index b97a67c9e..5762b926b 100644 --- a/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py @@ -36,7 +36,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Madissoon" self.disease = "healthy" - self.doi = "10.1186/s13059-019-1906-x" + self.doi_journal = "10.1186/s13059-019-1906-x" self.normalization = "raw" # ToDo "madissoon19_lung.processed.h5ad" is close to integer but not quire (~1e-4) self.organ = "lung parenchyma" if self.sample_fn == "madissoon19_lung.processed.h5ad" else \ "esophagus" if self.sample_fn == "oesophagus.cellxgene.h5ad" else "spleen" diff --git a/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py b/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py index 7b9501d25..3cecb29b4 100644 --- a/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py +++ b/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py @@ -17,7 +17,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Lukowski" self.disease = "healthy" - self.doi = "10.15252/embj.2018100811" + self.doi_journal = "10.15252/embj.2018100811" self.normalization = "raw" self.organ = "retina" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/dno_doi_10x_genomics/human_blood_2019_10xsequencing_10xgenomics_001.py b/sfaira/data/dataloaders/loaders/dno_doi_10x_genomics/human_blood_2019_10xsequencing_10xgenomics_001.py index e2aeeb90c..2358175df 100644 --- a/sfaira/data/dataloaders/loaders/dno_doi_10x_genomics/human_blood_2019_10xsequencing_10xgenomics_001.py +++ b/sfaira/data/dataloaders/loaders/dno_doi_10x_genomics/human_blood_2019_10xsequencing_10xgenomics_001.py @@ -18,7 +18,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "10x Genomics" self.disease = "healthy" - self.doi = "no_doi_10x_genomics" + self.doi_journal = "no_doi_10x_genomics" self.normalization = "raw" self.organ = "blood" self.organism = "human" diff --git a/sfaira/data/dataloaders/loaders/dno_doi_regev/human_x_2018_10xsequencing_regev_001.py b/sfaira/data/dataloaders/loaders/dno_doi_regev/human_x_2018_10xsequencing_regev_001.py index 242d861f0..b7cdb457f 100644 --- a/sfaira/data/dataloaders/loaders/dno_doi_regev/human_x_2018_10xsequencing_regev_001.py +++ b/sfaira/data/dataloaders/loaders/dno_doi_regev/human_x_2018_10xsequencing_regev_001.py @@ -15,7 +15,7 @@ def __init__(self, **kwargs): self.assay_sc = "10x technology" self.author = "Regev" 
self.disease = "healthy" - self.doi = "no_doi_regev" + self.doi_journal = "no_doi_regev" self.normalization = "raw" self.organ_obs_key = "derived_organ_parts_label" self.organism = "human" diff --git a/sfaira/data/utils_scripts/streamline_selected.py b/sfaira/data/utils_scripts/streamline_selected.py index b8b3e2ad5..74cfba10e 100644 --- a/sfaira/data/utils_scripts/streamline_selected.py +++ b/sfaira/data/utils_scripts/streamline_selected.py @@ -2,6 +2,8 @@ import sfaira import sys +from sfaira.data import clean_string + # Set global variables. print("sys.argv", sys.argv) @@ -12,13 +14,14 @@ schema = str(sys.argv[5]) dois = str(sys.argv[6]) -path_cache = path_cache if path_cache != "None" else None +path_cache = path_cache if path_cache.lower() != "none" else None +path_meta = path_meta if path_meta.lower() != "none" else None -for x in dois.split(","): +for doi in dois.split(","): ds = sfaira.data.dataloaders.Universe( data_path=data_path, meta_path=path_meta, cache_path=path_cache ) - ds.subset(key="doi", values=[x]) + ds.subset(key="doi", values=[doi]) ds.load( load_raw=False, allow_caching=True, @@ -42,4 +45,7 @@ dsg = ds.dataset_groups[0] for k, v in dsg.datasets.items(): fn = v.doi_cleaned_id + ".h5ad" - v.adata.write_h5ad(os.path.join(path_out, fn)) + dir_name = v.directory_formatted_doi + if not os.path.exists(os.path.join(path_out, dir_name)): + os.makedirs(os.path.join(path_out, dir_name)) + v.adata.write_h5ad(os.path.join(path_out, dir_name, fn)) diff --git a/sfaira/unit_tests/utils.py b/sfaira/unit_tests/utils.py index fa040445e..55667411e 100644 --- a/sfaira/unit_tests/utils.py +++ b/sfaira/unit_tests/utils.py @@ -65,7 +65,7 @@ def cached_store_writing(dir_data, dir_meta, assembly, organism: str = "mouse", ds.subset(key=adata_ids_sfaira.organ, values=[organ]) # Only load files that are not already in cache. anticipated_files = np.unique([ - v.doi[0] if isinstance(v.doi, list) else v.doi for k, v in ds.datasets.items() + v.doi_journal[0] if isinstance(v.doi_journal, list) else v.doi_journal for k, v in ds.datasets.items() if (not os.path.exists(os.path.join(store_path, v.doi_cleaned_id + "." 
+ store_format)) and store_format == "h5ad") or (not os.path.exists(os.path.join(store_path, v.doi_cleaned_id)) and store_format == "dao") From 6f59f7c748eac0544144d64137a42ca511963b45 Mon Sep 17 00:00:00 2001 From: Lukas Heumos Date: Wed, 21 Jul 2021 09:29:46 +0200 Subject: [PATCH 02/15] CLI improvements #321 #314 (#332) * add new adding datasets figure Signed-off-by: zethson * add sample_source Signed-off-by: zethson * add sample_source to validator Signed-off-by: zethson * renamed assay to assay_sc Signed-off-by: zethson * fix assay_sc template Signed-off-by: zethson * add remaining fields Signed-off-by: zethson * add some more documentation Signed-off-by: zethson * add cell_types_original_obs_key Signed-off-by: zethson * add sfaira annotate-dataloader hints Signed-off-by: zethson * fix flake8 Signed-off-by: zethson * remove unnecessary flake8 ignore Signed-off-by: zethson --- .github/workflows/build_package.yml | 2 +- .github/workflows/create_templates.yml | 5 ++-- docs/adding_datasets.rst | 10 +++---- sfaira/commands/create_dataloader.py | 29 +++++++++++++++---- .../multiple_datasets/cookiecutter.json | 6 ++-- .../{{ cookiecutter.id_without_doi }}.yaml | 23 +++++++++++---- .../single_dataset/cookiecutter.json | 6 ++-- .../{{ cookiecutter.id_without_doi }}.yaml | 17 ++++++++--- sfaira/commands/validate_dataloader.py | 25 ++++++++++++---- 9 files changed, 89 insertions(+), 34 deletions(-) diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index ec199d5a6..de6154635 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -31,7 +31,7 @@ jobs: - name: Import sfaira run: python -c "import sfaira" - # Verify that the package does adhere to PyPI's standards + # Verify that the package adheres to PyPI's standards - name: Install required twine packaging dependencies run: pip install setuptools wheel twine diff --git a/.github/workflows/create_templates.yml b/.github/workflows/create_templates.yml index 4e894113f..bcbcbcfa0 100644 --- a/.github/workflows/create_templates.yml +++ b/.github/workflows/create_templates.yml @@ -9,7 +9,6 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python: [3.7, 3.8] env: PYTHONIOENCODING: utf-8 @@ -20,7 +19,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v2.1.4 with: - python-version: ${{ matrix.python }} + python-version: 3.8 - name: Upgrade and install pip run: python -m pip install --upgrade pip @@ -31,5 +30,5 @@ jobs: - name: Create single_dataset template run: | cd .. - echo -e "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" | sfaira create-dataloader + echo -e "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" | sfaira create-dataloader rm -rf d10_1000_j_journal_2021_01_001/ diff --git a/docs/adding_datasets.rst b/docs/adding_datasets.rst index e2e297aa2..873dc2fe1 100644 --- a/docs/adding_datasets.rst +++ b/docs/adding_datasets.rst @@ -5,7 +5,7 @@ Adding datasets to sfaira is a great way to increase the visibility of your data This process requires a couple of steps as outlined in the following sections. -.. figure:: https://user-images.githubusercontent.com/21954664/117845386-c6744a00-b280-11eb-9d86-8c47132a3949.png +.. figure:: https://user-images.githubusercontent.com/21954664/126300611-c5ba18b7-7c88-4bb1-8865-a20587cd5f7b.png :alt: sfaira adding datasets Overview of contributing dataloaders to sfaira. First, ensure that your data is not yet available as a dataloader. 
@@ -185,7 +185,8 @@ before it is loaded into memory: sample_fns: dataset_wise: author: - doi: + doi_preprint: + doi_journal: download_url_data: download_url_meta: normalization: @@ -254,9 +255,8 @@ In summary, a the dataloader for a mouse lung data set could look like this: sample_fns: dataset_wise: author: "me" - doi: - - "my preprint" - - "my peer-reviewed publication" + doi_preprint: "my preprint" + doi_journal: "my journal" download_url_data: "my GEO upload" download_url_meta: normalization: "raw" diff --git a/sfaira/commands/create_dataloader.py b/sfaira/commands/create_dataloader.py index 286a88511..ef674ca84 100644 --- a/sfaira/commands/create_dataloader.py +++ b/sfaira/commands/create_dataloader.py @@ -27,16 +27,19 @@ class TemplateAttributes: download_url_meta: str = '' # download website(s) of meta data files organ: str = '' # (*) organ (anatomical structure) organism: str = '' # (*) species / organism - assay: str = '' # (*, optional) protocol used to sample data (e.g. smart-seq2) + assay_sc: str = '' # (*, optional) protocol used to sample data (e.g. smart-seq2) normalization: str = '' # raw or the used normalization technique default_embedding: str = '' # Default embedding of the data primary_data: str = '' # Is this a primary dataset? disease: str = '' # name of the disease of the condition ethnicity: str = '' # ethnicity of the sample + sample_source: str = '' # source of the sample state_exact: str = '' # state of the sample year: str = 2021 # year in which sample was acquired number_of_datasets: str = 1 # Required to determine the file names + cell_types_original_obs_key: str = '' # Original cell type key in obs + class DataloaderCreator: @@ -81,7 +84,7 @@ def _prompt_dataloader_configuration(self): question='DOI:', default='10.1000/j.journal.2021.01.001') while not re.match(r'\b10\.\d+/[\w.]+\b', doi): - print('[bold red]The entered DOI is malformed!') # noqa: W605 + print('[bold red]The entered DOI is malformed!') doi = sfaira_questionary(function='text', question='DOI:', default='10.1000/j.journal.2021.01.001') @@ -117,9 +120,9 @@ def _prompt_dataloader_configuration(self): self.template_attributes.organ = sfaira_questionary(function='text', question='Organ:', default='NA') - self.template_attributes.assay = sfaira_questionary(function='text', - question='Assay:', - default='NA') + self.template_attributes.assay_sc = sfaira_questionary(function='text', + question='Assay:', + default='NA') self.template_attributes.normalization = sfaira_questionary(function='text', question='Normalization:', default='raw') @@ -129,6 +132,16 @@ def _prompt_dataloader_configuration(self): self.template_attributes.state_exact = sfaira_questionary(function='text', question='Sample state:', default='healthy') + self.template_attributes.sample_source = sfaira_questionary(function='text', + question='Sample source:', + default='NA') + is_cell_type_annotation = sfaira_questionary(function='confirm', + question='Does your dataset have a cell type annotation?', + default='No') + if is_cell_type_annotation: + self.template_attributes.cell_types_original_obs_key = sfaira_questionary(function='text', + question='Cell type annotation obs key:', + default='') self.template_attributes.year = sfaira_questionary(function='text', question='Year:', default='2021') @@ -139,7 +152,7 @@ def _prompt_dataloader_configuration(self): print('[bold yellow] First author was not in the expected format. 
Using full first author for the id.') first_author_lastname = first_author self.template_attributes.id_without_doi = f'{self.template_attributes.organism}_{self.template_attributes.organ}_' \ - f'{self.template_attributes.year}_{self.template_attributes.assay}_' \ + f'{self.template_attributes.year}_{self.template_attributes.assay_sc}_' \ f'{first_author_lastname}_001' self.template_attributes.id = self.template_attributes.id_without_doi + f'_{self.template_attributes.doi_sfaira_repr}' if self.template_attributes.dataloader_type == 'single_dataset': @@ -152,6 +165,10 @@ def _prompt_dataloader_configuration(self): self.template_attributes.create_extra_description = sfaira_questionary(function='confirm', question='Do you want to add additional custom metadata?', default='Yes') + if is_cell_type_annotation: + print('[bold blue]You will have to run \'sfaira annotate-dataloader\' after the template has been created and filled.') + else: + print('[bold blue]You can skip \'sfaira annotate-dataloader\'.') def _template_attributes_to_dict(self) -> dict: """ diff --git a/sfaira/commands/templates/multiple_datasets/cookiecutter.json b/sfaira/commands/templates/multiple_datasets/cookiecutter.json index ddd7f38ba..fe3d23412 100644 --- a/sfaira/commands/templates/multiple_datasets/cookiecutter.json +++ b/sfaira/commands/templates/multiple_datasets/cookiecutter.json @@ -13,11 +13,13 @@ "normalization": "", "organ": "", "organism": "", - "assay": "", + "assay_sc": "", "year": "", "individual": "", + "sample_source": "", "state_exact": "", "primary_data": "", "default_embedding": "", - "create_extra_description": "" + "create_extra_description": "", + "cell_types_original_obs_key": "" } diff --git a/sfaira/commands/templates/multiple_datasets/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml b/sfaira/commands/templates/multiple_datasets/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml index 0a10f0d0a..24285b789 100644 --- a/sfaira/commands/templates/multiple_datasets/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml +++ b/sfaira/commands/templates/multiple_datasets/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml @@ -6,7 +6,8 @@ dataset_structure: author: "{{ cookiecutter.author }}" default_embedding: {% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: "{{ cookiecutter.default_embedding }}" -{% endfor %}doi: "{{ cookiecutter.doi }}" +{% endfor %} doi_preprint: + doi_journal: "{{ cookiecutter.doi }}" download_url_data: {% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: {% endfor %} download_url_meta: @@ -15,12 +16,21 @@ dataset_structure: normalization: "{{ cookiecutter.normalization }}" year: "{{ cookiecutter.year }}" dataset_or_observation_wise: - assay: -{% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: "{{ cookiecutter.assay }}" -{% endfor %} assay_obs_key: + assay_sc: +{% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: "{{ cookiecutter.assay_sc }}" +{% endfor %} assay_sc_obs_key: + assay_differentiation: +{% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: +{% endfor %} assay_differentiation_obs_key: + assay_type_differentiation: +{% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: +{% endfor %} assay_type_differentiation_obs_key: bio_sample: {% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: {% endfor %} bio_sample_obs_key: + cell_line: +{% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: +{% endfor %} cell_line_obs_key: development_stage: {% for fn in cookiecutter.sample_fns.fns %} 
{{ fn }}: {% endfor %} development_stage_obs_key: @@ -39,6 +49,9 @@ dataset_or_observation_wise: organism: {% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: "{{ cookiecutter.organism }}" {% endfor %} organism_obs_key: + sample_source: +{% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: "{{ cookiecutter.sample_source }}" +{% endfor %} sample_source_obs_key: sex: {% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: {% endfor %} sex_obs_key: @@ -49,7 +62,7 @@ dataset_or_observation_wise: {% for fn in cookiecutter.sample_fns.fns %} {{ fn }}: {% endfor %} tech_sample_obs_key: observation_wise: - cell_types_original_obs_key: + cell_types_original_obs_key: "{{ cookiecutter.cell_types_original_obs_key }}" feature_wise: gene_id_ensembl_var_key: gene_id_symbols_var_key: diff --git a/sfaira/commands/templates/single_dataset/cookiecutter.json b/sfaira/commands/templates/single_dataset/cookiecutter.json index faafc16cd..224f7300a 100644 --- a/sfaira/commands/templates/single_dataset/cookiecutter.json +++ b/sfaira/commands/templates/single_dataset/cookiecutter.json @@ -13,11 +13,13 @@ "normalization": "", "organ": "", "organism": "", - "assay": "", + "assay_sc": "", "year": "", "individual": "", + "sample_source": "", "state_exact": "", "primary_data": "", "default_embedding": "", - "create_extra_description": "" + "create_extra_description": "", + "cell_types_original_obs_key": "" } diff --git a/sfaira/commands/templates/single_dataset/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml b/sfaira/commands/templates/single_dataset/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml index 1beb00a25..258606767 100644 --- a/sfaira/commands/templates/single_dataset/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml +++ b/sfaira/commands/templates/single_dataset/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml @@ -5,17 +5,24 @@ dataset_structure: dataset_wise: author: "{{ cookiecutter.author }}" default_embedding: "{{ cookiecutter.default_embedding }}" - doi: "{{ cookiecutter.doi }}" + doi_preprint: + doi_journal: "{{ cookiecutter.doi }}" download_url_data: "{{ cookiecutter.download_url_data }}" download_url_meta: "{{ cookiecutter.download_url_meta }}" primary_data: {{ cookiecutter.primary_data }} normalization: "{{ cookiecutter.normalization }}" year: "{{ cookiecutter.year }}" dataset_or_observation_wise: - assay: "{{ cookiecutter.assay }}" - assay_obs_key: + assay_sc: "{{ cookiecutter.assay_sc }}" + assay_sc_obs_key: + assay_differentiation: + assay_differentiation_obs_key: + assay_type_differentiation: + assay_type_differentiation_obs_key: bio_sample: bio_sample_obs_key: + cell_line: + cell_line_obs_key: development_stage: development_stage_obs_key: disease: "{{ cookiecutter.disease }}" @@ -28,6 +35,8 @@ dataset_or_observation_wise: organ_obs_key: organism: "{{ cookiecutter.organism }}" organism_obs_key: + sample_source: "{{ cookiecutter.sample_source }}" + sample_source_obs_key: sex: sex_obs_key: state_exact: "{{ cookiecutter.state_exact }}" @@ -35,7 +44,7 @@ dataset_or_observation_wise: tech_sample: tech_sample_obs_key: observation_wise: - cell_types_original_obs_key: + cell_types_original_obs_key: "{{ cookiecutter.cell_types_original_obs_key }}" feature_wise: gene_id_ensembl_var_key: gene_id_symbols_var_key: diff --git a/sfaira/commands/validate_dataloader.py b/sfaira/commands/validate_dataloader.py index 54c1af116..19371026b 100644 --- a/sfaira/commands/validate_dataloader.py +++ 
b/sfaira/commands/validate_dataloader.py @@ -50,22 +50,35 @@ def _validate_required_attributes(self): attributes = ['dataset_structure:sample_fns', 'dataset_wise:author', - 'dataset_wise:doi', + ['dataset_wise:doi_preprint', + 'dataset_wise:doi_journal'], 'dataset_wise:download_url_data', 'dataset_wise:download_url_meta', 'dataset_wise:normalization', 'dataset_wise:year', - 'dataset_or_observation_wise:assay', + 'dataset_or_observation_wise:assay_sc', 'dataset_or_observation_wise:organ', - 'dataset_or_observation_wise:organism'] + 'dataset_or_observation_wise:organism', + 'dataset_or_observation_wise:sample_source', + ['feature_wise:gene_id_ensembl_var_key', + 'feature_wise:gene_id_symbol_var_key']] + # TODO This is some spaghetti which could be more performant with set look ups. flattened_dict = flatten(self.content, reducer=make_reducer(delimiter=':')) for attribute in attributes: try: detected = False - for key in flattened_dict.keys(): - if key.startswith(attribute): - detected = True + for key, val in flattened_dict.items(): + # Lists of attributes are handled in the following way: + # One of the two keys must be present and one of them has to have a value + if isinstance(attribute, list): + for sub_attribute in attribute: + if key.startswith(sub_attribute) and val: + detected = True + # Single string that has to have a value + else: + if key.startswith(attribute) and val: + detected = True if not detected: passed_required_attributes = False self.failed['-1'] = f'Missing attribute: {attribute}' From 7c92d877de3556ad57b9926ba3265b4bbbf78c2c Mon Sep 17 00:00:00 2001 From: "David S. Fischer" Date: Wed, 21 Jul 2021 09:35:52 +0200 Subject: [PATCH 03/15] added lazy ontology loading in OCS (#334) * added lazy ontology loading * updated ontology unit tests to new EFO --- .gitignore | 1 + sfaira/consts/ontologies.py | 60 +++++++++++++++---- sfaira/unit_tests/versions/test_ontologies.py | 16 ++--- 3 files changed, 59 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index ebbb5cd2d..424f7eb03 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ docs/api/ sfaira/unit_tests/test_data_loaders/* sfaira/unit_tests/test_data/* sfaira/unit_tests/template_data/* +sfaira/unit_tests/mock_data/store_* # General patterns: git abuild diff --git a/sfaira/consts/ontologies.py b/sfaira/consts/ontologies.py index 76d9f77ff..7eec586d9 100644 --- a/sfaira/consts/ontologies.py +++ b/sfaira/consts/ontologies.py @@ -1,29 +1,40 @@ +from typing import Dict, Union + from sfaira.versions.metadata import OntologyList, OntologyCl from sfaira.versions.metadata import OntologyCellosaurus, OntologyHsapdv, OntologyMondo, \ OntologyMmusdv, OntologySinglecellLibraryConstruction, OntologyUberon +DEFAULT_CL = "v2021-02-01" + class OntologyContainerSfaira: - _cellontology_class: OntologyCl + """ + The attributes that are relayed via properties, which have a corresponding private attribute "_*", are used to + lazily load these ontologies upon usage and redistribute loading time from package initialisation to actual + usage of ontology. 
+ """ + + _assay_sc: Union[None, OntologySinglecellLibraryConstruction] + _cell_line: Union[None, OntologyCellosaurus] + _cellontology_class: Union[None, OntologyCl] + _development_stage: Union[None, Dict[str, Union[OntologyHsapdv, OntologyMmusdv]]] + _organ: Union[None, OntologyUberon] def __init__(self): self.annotated = OntologyList(terms=[True, False]) self.author = None self.assay_differentiation = None - self.assay_sc = OntologySinglecellLibraryConstruction() + self._assay_sc = None self.assay_type_differentiation = OntologyList(terms=["guided", "unguided"]) self.bio_sample = None - self.cell_line = OntologyCellosaurus() - self.cellontology_class = "v2021-02-01" + self._cell_line = None + self._cellontology_class = None self.cell_types_original = None self.collection_id = None self.default_embedding = None - self.development_stage = { - "human": OntologyHsapdv(), - "mouse": OntologyMmusdv(), - } - self.disease = OntologyMondo() + self._development_stage = None + self._disease = None self.doi_journal = None self.doi_preprint = None self.ethnicity = { @@ -33,7 +44,7 @@ def __init__(self): self.id = None self.individual = None self.normalization = None - self.organ = OntologyUberon() + self._organ = OntologyUberon() self.organism = OntologyList(terms=["mouse", "human"]) # TODO introduce NCBItaxon here self.primary_data = OntologyList(terms=[True, False]) self.sample_source = OntologyList(terms=["primary_tissue", "2d_culture", "3d_culture", "tumor"]) @@ -43,10 +54,39 @@ def __init__(self): self.title = None self.year = OntologyList(terms=list(range(2000, 3000))) + @property + def assay_sc(self): + if self._assay_sc is None: + self._assay_sc = OntologySinglecellLibraryConstruction() + return self._assay_sc + + @property + def cell_line(self): + if self._cell_line is None: + self._cell_line = OntologyCellosaurus() + return self._cell_line + @property def cellontology_class(self): + if self._cellontology_class is None: + self._cellontology_class = OntologyCl(branch=DEFAULT_CL) return self._cellontology_class @cellontology_class.setter def cellontology_class(self, x: str): self._cellontology_class = OntologyCl(branch=x) + + @property + def development_stage(self): + if self._development_stage is None: + self._development_stage = { + "human": OntologyHsapdv(), + "mouse": OntologyMmusdv(), + } + return self._development_stage + + @property + def disease(self): + if self._disease is None: + self._disease = OntologyMondo() + return self._disease diff --git a/sfaira/unit_tests/versions/test_ontologies.py b/sfaira/unit_tests/versions/test_ontologies.py index 9d568919e..301bc66d6 100644 --- a/sfaira/unit_tests/versions/test_ontologies.py +++ b/sfaira/unit_tests/versions/test_ontologies.py @@ -124,8 +124,8 @@ def test_sclc_nodes(): Tests for presence and absence of a few commonly mistaken nodes. """ sclc = OntologySinglecellLibraryConstruction() - assert "10x sequencing" in sclc.node_names - assert "10x 5' v3 sequencing" in sclc.node_names + assert "10x technology" in sclc.node_names + assert "10x 5' v3" in sclc.node_names assert "Smart-like" in sclc.node_names assert "Smart-seq2" in sclc.node_names assert "sci-plex" in sclc.node_names @@ -137,13 +137,13 @@ def test_sclc_is_a(): Tests if is-a relationships work correctly. 
""" sclc = OntologySinglecellLibraryConstruction() - assert sclc.is_a(query="10x v1 sequencing", reference="10x sequencing") - assert sclc.is_a(query="10x 5' v3 sequencing", reference="10x sequencing") - assert sclc.is_a(query="10x 5' v3 sequencing", reference="10x v3 sequencing") - assert not sclc.is_a(query="10x sequencing", reference="10x v1 sequencing") - assert sclc.is_a(query="10x 5' v3 sequencing", reference="single cell library construction") + assert sclc.is_a(query="10x v1", reference="10x technology") + assert sclc.is_a(query="10x 5' v3", reference="10x technology") + assert sclc.is_a(query="10x 5' v3", reference="10x v3") + assert not sclc.is_a(query="10x technology", reference="10x v1") + assert sclc.is_a(query="10x 5' v3", reference="single cell library construction") assert sclc.is_a(query="sci-plex", reference="single cell library construction") - assert not sclc.is_a(query="sci-plex", reference="10x sequencing") + assert not sclc.is_a(query="sci-plex", reference="10x technology") """ From 7f2b19b17f87230bd50432f0ddb9ba03f8e37527 Mon Sep 17 00:00:00 2001 From: "David S. Fischer" Date: Wed, 21 Jul 2021 11:16:18 +0200 Subject: [PATCH 04/15] added uberon to lazily loaded ontologies in ocs (#335) --- sfaira/consts/ontologies.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sfaira/consts/ontologies.py b/sfaira/consts/ontologies.py index 7eec586d9..a83003939 100644 --- a/sfaira/consts/ontologies.py +++ b/sfaira/consts/ontologies.py @@ -44,7 +44,7 @@ def __init__(self): self.id = None self.individual = None self.normalization = None - self._organ = OntologyUberon() + self._organ = None self.organism = OntologyList(terms=["mouse", "human"]) # TODO introduce NCBItaxon here self.primary_data = OntologyList(terms=[True, False]) self.sample_source = OntologyList(terms=["primary_tissue", "2d_culture", "3d_culture", "tumor"]) @@ -90,3 +90,9 @@ def disease(self): if self._disease is None: self._disease = OntologyMondo() return self._disease + + @property + def organ(self): + if self._organ is None: + self._organ = OntologyUberon() + return self._organ From 9e90dee8b86396135da6baa14b8de11a842ae569 Mon Sep 17 00:00:00 2001 From: "David S. 
Fischer" Date: Mon, 26 Jul 2021 10:39:22 +0200 Subject: [PATCH 05/15] reassigned gamma cell in pancreas to pancreatic PP cell CL:0002275 (#338) - affects d10_1016_j_cmet_2016_08_020, d10_1016_j_cels_2016_08_011 --- .../human_pancreas_2016_indrop_baron_001.py | 22 ------------------- .../human_pancreas_2016_indrop_baron_001.tsv | 2 +- ...pancreas_2016_smartseq2_segerstolpe_001.py | 1 - ...ancreas_2016_smartseq2_segerstolpe_001.tsv | 2 +- 4 files changed, 2 insertions(+), 25 deletions(-) diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py index a8fe3edef..d07a2ada2 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py @@ -7,9 +7,6 @@ class Dataset(DatasetBase): - """ - ToDo: revisit gamma cell missing in CO - """ def __init__(self, **kwargs): super().__init__(**kwargs) @@ -32,25 +29,6 @@ def __init__(self, **kwargs): self.set_dataset_id(idx=1) - self.class_maps = { - "0": { - "t_cell": "T cell", - "quiescent_stellate": "Quiescent Stellate cell", - "mast": "Mast cell", - "delta": "Delta cell", - "beta": "Beta cell", - "endothelial": "Endothelial cell", - "macrophage": "Macrophage", - "epsilon": "Epsilon cell", - "activated_stellate": "Activated Stellate cell", - "acinar": "Acinar cell", - "alpha": "Alpha cell", - "ductal": "Ductal cell", - "schwann": "Schwann cell", - "gamma": "Gamma cell", - }, - } - def load(data_dir, **kwargs): fn = os.path.join(data_dir, "baron16.processed.h5ad") diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.tsv b/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.tsv index 0ba8de392..e935becbc 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.tsv +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.tsv @@ -7,7 +7,7 @@ delta pancreatic D cell CL:0000173 ductal pancreatic ductal cell CL:0002079 endothelial endothelial cell CL:0000115 epsilon pancreatic epsilon cell CL:0005019 -gamma pancreatic endocrine cell CL:0008024 +gamma pancreatic PP cell CL:0002275 macrophage macrophage CL:0000235 mast mast cell CL:0000097 quiescent_stellate pancreatic stellate cell CL:0002410 diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py index d7dcf8b54..b9c4c2657 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py @@ -7,7 +7,6 @@ class Dataset(DatasetBase): """ - ToDo: revisit gamma cell missing in CO TODO: move state exact to diesase """ diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.tsv b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.tsv index 8d536b19b..58bc8c851 100644 --- 
a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.tsv +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.tsv @@ -9,7 +9,7 @@ delta cell pancreatic D cell CL:0000173 ductal cell pancreatic ductal cell CL:0002079 endothelial cell endothelial cell CL:0000115 epsilon cell pancreatic epsilon cell CL:0005019 -gamma cell pancreatic endocrine cell CL:0008024 +gamma cell pancreatic PP cell CL:0002275 mast cell mast cell CL:0000097 unclassified cell UNKNOWN UNKNOWN unclassified endocrine cell pancreatic endocrine cell CL:0008024 From 1580cd34da519670f0d8220c7887c294aa2aa05b Mon Sep 17 00:00:00 2001 From: "David S. Fischer" Date: Mon, 26 Jul 2021 11:50:42 +0200 Subject: [PATCH 06/15] added new edge types (#341) --- sfaira/versions/metadata/base.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/sfaira/versions/metadata/base.py b/sfaira/versions/metadata/base.py index cc9a4132c..73d6fc841 100644 --- a/sfaira/versions/metadata/base.py +++ b/sfaira/versions/metadata/base.py @@ -614,6 +614,7 @@ def __init__( edge_types = [ 'aboral_to', 'adjacent_to', + 'ambiguous_for_taxon', 'anastomoses_with', 'anterior_to', 'anteriorly_connected_to', @@ -621,14 +622,18 @@ def __init__( 'attaches_to_part_of', 'bounding_layer_of', 'branching_part_of', + 'capable_of', + 'capable_of_part_of', 'channel_for', 'channels_from', 'channels_into', 'composed_primarily_of', 'conduit_for', + 'confers_advantage_in', 'connected_to', 'connects', 'contains', + 'contains_process', 'continuous_with', 'contributes_to_morphology_of', 'deep_to', @@ -643,6 +648,7 @@ def __init__( 'distalmost_part_of', 'dorsal_to', 'drains', + 'dubious_for_taxon', 'ends', 'ends_with', 'existence_ends_during', @@ -654,6 +660,7 @@ def __init__( 'existence_starts_with', 'extends_fibers_into', 'filtered_through', + 'functionally_related_to', 'has_boundary', 'has_component', 'has_developmental_contribution_from', @@ -665,6 +672,7 @@ def __init__( 'has_part', 'has_potential_to_develop_into', 'has_potential_to_developmentally_contribute_to', + 'has_quality', 'has_skeleton', 'immediate_transformation_of', 'immediately_anterior_to', @@ -685,10 +693,12 @@ def __init__( 'in_proximal_side_of', 'in_right_side_of', 'in_superficial_part_of', + 'in_taxon', 'in_ventral_side_of', 'indirectly_supplies', 'innervated_by', 'innervates', + 'input_of', 'intersects_midsagittal_plane_of', 'is_a', # term DAG -> include because it connect conceptual tissue groups 'layer_part_of', @@ -696,23 +706,34 @@ def __init__( 'location_of', 'lumen_of', 'luminal_space_of', + 'negatively_regulates', + 'never_in_taxon', + 'occurs_in', + 'only_in_taxon', + 'output_of', 'overlaps', 'part_of', # anatomic DAG -> include because it reflects the anatomic coarseness / hierarchy + 'participates_in', + 'positively_regulates', 'postaxialmost_part_of', 'posterior_to', 'posteriorly_connected_to', 'preaxialmost_part_of', 'preceded_by', 'precedes', + 'present_in_taxon', 'produced_by', 'produces', 'protects', 'proximal_to', 'proximally_connected_to', 'proximalmost_part_of', + 'regulates', 'seeAlso', 'serially_homologous_to', 'sexually_homologous_to', + 'simultaneous_with', + 'site_of', 'skeleton_of', 'starts', 'starts_with', @@ -721,6 +742,7 @@ def __init__( 'supplies', 'surrounded_by', 'surrounds', + 'synapsed_by', 'transformation_of', 'tributary_of', 'trunk_part_of', @@ -728,7 +750,8 @@ def __init__( ] edges_to_delete = [] for i, x in 
enumerate(self.graph.edges): - assert x[2] in edge_types, x + if x[2] not in edge_types: + print(f"NON-CRITICAL WARNING: uberon edge type {x[2]} not in reference list yet") if x[2] not in [ "develops_from", 'develops_from_part_of', @@ -804,7 +827,8 @@ def __init__( else: edges_allowed = ["is_a"] for i, x in enumerate(self.graph.edges): - assert x[2] in edge_types, x + if x[2] not in edge_types: + print(f"NON-CRITICAL WARNING: cl edge type {x[2]} not in reference list yet") if x[2] not in edges_allowed: edges_to_delete.append((x[0], x[1])) for x in edges_to_delete: From ca029ddf97ee32800d693a38c7f47ead386bc730 Mon Sep 17 00:00:00 2001 From: "David S. Fischer" Date: Mon, 26 Jul 2021 13:17:42 +0200 Subject: [PATCH 07/15] Improve CLI documentation (#320) * improved error reporting in annotate * improved file not found reporting in annotate * update template creation workflow * fix doi promting * update download urls * fix data path handling in CLI * fix disease default in cli * fix test-dataloader [skip ci] * fix CI (#339) Co-authored-by: david.seb.fischer Co-authored-by: le-ander <20015434+le-ander@users.noreply.github.com> Co-authored-by: Lukas Heumos --- .github/workflows/create_templates.yml | 1 - docs/adding_datasets.rst | 79 +++++++-------- docs/consuming_data.rst | 2 +- sfaira/cli.py | 99 +++++++++++++------ sfaira/commands/annotate_dataloader.py | 46 ++++++--- sfaira/commands/clean_dataloader.py | 24 ----- sfaira/commands/create_dataloader.py | 37 ++++--- .../{{ cookiecutter.id_without_doi }}.yaml | 1 - .../{{ cookiecutter.id_without_doi }}.yaml | 1 - sfaira/commands/test_dataloader.py | 81 +++++++++++---- sfaira/commands/validate_dataloader.py | 23 ++++- sfaira/consts/utils.py | 11 +++ sfaira/data/__init__.py | 3 +- sfaira/data/base/__init__.py | 2 +- sfaira/data/base/dataset.py | 15 +-- sfaira/data/base/dataset_group.py | 14 ++- .../human_blood_2020_10x_hao_001.yaml | 4 +- .../data/dataloaders/loaders/super_group.py | 2 +- .../data/utils_scripts/streamline_selected.py | 2 - .../unit_tests/data_contribution/__init__.py | 0 .../data_contribution/test_data_template.py | 64 ------------ 21 files changed, 281 insertions(+), 230 deletions(-) delete mode 100644 sfaira/commands/clean_dataloader.py create mode 100644 sfaira/consts/utils.py delete mode 100644 sfaira/unit_tests/data_contribution/__init__.py delete mode 100644 sfaira/unit_tests/data_contribution/test_data_template.py diff --git a/.github/workflows/create_templates.yml b/.github/workflows/create_templates.yml index bcbcbcfa0..f45eb0aee 100644 --- a/.github/workflows/create_templates.yml +++ b/.github/workflows/create_templates.yml @@ -29,6 +29,5 @@ jobs: - name: Create single_dataset template run: | - cd .. echo -e "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" | sfaira create-dataloader rm -rf d10_1000_j_journal_2021_01_001/ diff --git a/docs/adding_datasets.rst b/docs/adding_datasets.rst index 873dc2fe1..e39411e45 100644 --- a/docs/adding_datasets.rst +++ b/docs/adding_datasets.rst @@ -9,7 +9,7 @@ This process requires a couple of steps as outlined in the following sections. :alt: sfaira adding datasets Overview of contributing dataloaders to sfaira. First, ensure that your data is not yet available as a dataloader. - Next, create a dataloader and validate it. Afterwards, annotate it to finally test it. Finally, submit your dataloader to sfaira. + Next, create a dataloader. Afterwards, validate/annotate it to finally test it. Finally, submit your dataloader to sfaira. 
sfaira features an interactive way of creating, formatting and testing dataloaders through a command line interface (CLI). The common workflow using the CLI looks as follows: @@ -24,7 +24,7 @@ The common workflow using the CLI looks as follows: preprint and publication DOIs if both are available. We will also mention publication names in issues, you will however not find these in the code. -.. _code: https://github.com/theislab/sfaira/tree/dev +.. _code: https://github.com/theislab/sfaira/tree/dev/sfaira/data/dataloaders/loaders .. _issues: https://github.com/theislab/sfaira/issues 2. Install sfaira. @@ -43,93 +43,88 @@ The common workflow using the CLI looks as follows: 3. Create a new dataloader. When creating a dataloader with ``sfaira create-dataloader`` dataloader specific attributes such as organ, organism and many more are prompted for. - We provide a description of all meta data items at the bottom of this file. + We provide a description of all meta data items at the bottom of this page. If the requested information is not available simply hit enter and continue until done. .. code-block:: # make sure you are in the top-level sfaira directory from step 1 git checkout -b YOUR_BRANCH_NAME # create a new branch for your data loader. - sfaira create-dataloader + sfaira create-dataloader [--doi] [--path_loader] [--path_data] - -The created files are created in the sfaira installation under `sfaira/data/dataloaders/loaders/--DOI-folder--`, +If `--doi` is not provided in the command above, the user will be prompted to enter it in the creation process. +If `--path-loader` is not provided the following default location will be used: `./sfaira/data/dataloaders/loaders/`. +If `--path-data` is not provided, the empty folder for the data files will be created in the following default location: `./sfaira/unit_tests/template_data/`. +The created files are created in the sfaira installation under `/--DOI-folder--`, where the DOI-specific folder starts with `d` and is followed by the DOI in which all special characters are replaced by `_`, below referred to as `--DOI-folder--`: .. code-block:: - ├──sfaira/data/dataloaders/loaders/--DOI-folder-- + ├── /--DOI-folder-- ├── extra_description.txt <- Optional extra description file ├── __init__.py ├── NA_NA_2021_NA_Einstein_001.py <- Contains the load function to load the data ├── NA_NA_2021_NA_Einstein_001.yaml <- Specifies all data loader data + ├── /--DOI-folder-- .. 4. Correct yaml file. - Correct errors in `sfaira/data/dataloaders/loaders/--DOI-folder--/NA_NA_2021_NA_Einstein_001.yaml` file and add + Correct errors in `/--DOI-folder--/NA_NA_2021_NA_Einstein_001.yaml` file and add further attributes you may have forgotten in step 2. This step is optional. 5. Make downloaded data available to sfaira data loader testing. - Identify the raw files as indicated in the dataloader classes and copy them into your directory structure as - required by your data loader. - Note that this should be the exact files that are uploaded to cloud servers such as GEO: - Do not decompress these files ff these files are archives such as zip, tar or gz. + Identify the raw data files as indicated in the dataloader classes and copy them into the datafolder created by + the previous command (`/--DOI-folder--/`). + Note that this should be the exact files that are downloadable from the download URL you provided in the dataloader. + Do not decompress these files if these files are archives such as zip, tar or gz. 
Instead, navigate the archives directly in the load function (step 5). - Copy the data into `sfaira/unit_tests/template_data/--DOI-folder--/`. + Copy the data into `/--DOI-folder--/`. This folder is masked from git and only serves for temporarily using this data for loader testing. After finishing loader contribution, you can delete this data again without any consequences for your loader. 6. Write load function. - Fill load function in `sfaira/data/dataloaders/loaders/--DOI-folder--NA_NA_2021_NA_Einstein_001.py`. - -7. Validate the dataloader with the CLI. - Next validate the integrity of your dataloader content with ``sfaira validate-dataloader ``. - All tests must pass! If any of the tests fail please revisit your dataloader and add the missing information. - -.. code-block:: + Complete the load function in `/--DOI-folder--/NA_NA_2021_NA_Einstein_001.py`. - # make sure you are in the top-level sfaira directory from step 1 - sfaira validate-dataloader `` -.. - -8. Create cell type annotation if your data set is annotated. +7. Create cell type annotation if your data set is annotated. + This function will run fuzzy string matching between the annotations in the metadata column you provided in the + `cell_types_original_obs_key` attribute of the yaml file and the Cell Ontology Database. Note that this will abort with error if there are bugs in your data loader. .. code-block:: # make sure you are in the top-level sfaira directory from step 1 - # sfaira annotate `` TODO + sfaira annotate-dataloader [--doi] [--path_loader] [--path_data] .. -9. Mitigate automated cell type maps. - Sfaira creates a cell type mapping `.tsv` file in the directory in which your data loaders is located if you - indicated that annotation is present by filling `cell_types_original_obs_key`. - This file is: `NA_NA_2021_NA_Einstein_001.tsv`. +8. Clean up the automated cell type maps. + Sfaira creates suggestions for cell type mapping in a `.tsv` file in the directory in which your data loaders is + located if you indicated that annotation is present by filling `cell_types_original_obs_key`. + This file is: `/--DOI-folder--/NA_NA_2021_NA_Einstein_001.tsv`. This file contains two columns with one row for each unique cell type label. The free text identifiers in the first column "source", and the corresponding ontology term in the second column "target". - You can write this file entirely from scratch. - Sfaira also allows you to generate a first guess of this file using fuzzy string matching - which is automatically executed when you run the template data loader unit test for the first time with you new - loader. - Conflicts are not resolved in this first guess and you have to manually decide which free text field corresponds - to which ontology term in the case of conflicts. - Still, this first guess usually drastically speeds up this annotation harmonization. - Note that you do not have to include the non-human-readable IDs here as they are added later in a fully + After running the `annotate-dataloader` function, you can find a number of suggestions for matching the existing + celltype labels to cell labels from the cell ontology. It is now up to you to pick the best match from the + suggestions and delete all others from the line in the `.tsv` file. In certain cases the string matching might + not give the desired result. In such a case you can manually search the Cell Ontology database for the best + match via the OLS_ web-interface. 
+ Note that you do not have to include the non-human-readable `target_id` here as they are added later in a fully automated fashion. -10. Test data loader. +.. _OLS:https://www.ebi.ac.uk/ols/ontologies/cl + +9. Test data loader. Note that this will abort with error if there are bugs in your data loader. .. code-block:: # make sure you are in the top-level sfaira directory from step 1 - # sfaira test-dataloader `` TODO + sfaira test-dataloader [--doi] [--path_loader] [--path_data] .. -11. Make loader public. +10. Make loader public. You can contribute the data loader to public sfaira as code through a pull request. Note that you can also just keep the data loader in your local installation or keep it in sfaira_extensions if you do not want to make it public. @@ -151,7 +146,7 @@ by `_`, below referred to as `--DOI-folder--`: .. The following sections will first describe the underlying design principles of sfaira dataloaders and -then explain how to interactively create, validate and test dataloaders. +then explain how to interactively create, annotate and test dataloaders. Writing dataloaders diff --git a/docs/consuming_data.rst b/docs/consuming_data.rst index 8037892fc..beb5a2466 100644 --- a/docs/consuming_data.rst +++ b/docs/consuming_data.rst @@ -1,4 +1,4 @@ -Consuming Data +Consuming data =============== .. image:: https://raw.githubusercontent.com/theislab/sfaira/master/resources/images/data_zoo.png diff --git a/sfaira/cli.py b/sfaira/cli.py index b68ee6fc8..b581d77b9 100644 --- a/sfaira/cli.py +++ b/sfaira/cli.py @@ -1,6 +1,7 @@ import logging import os import sys +import re import click import rich @@ -11,7 +12,6 @@ from sfaira.commands.annotate_dataloader import DataloaderAnnotater from sfaira.commands.test_dataloader import DataloaderTester -from sfaira.commands.clean_dataloader import DataloaderCleaner from sfaira.commands.validate_dataloader import DataloaderValidator import sfaira @@ -72,63 +72,98 @@ def sfaira_cli(ctx, verbose, log_file): @sfaira_cli.command() -def create_dataloader() -> None: +@click.option('--path-loader', + default="sfaira/data/dataloaders/loaders/", + type=click.Path(exists=True), + help='Relative path from the current directory to the desired location of the dataloader.' + ) +@click.option('--path-data', + default="sfaira/unit_tests/template_data/", + type=click.Path(exists=False), + help='Relative path from the current directory to the datafiles used by this dataloader.' + ) +@click.option('--doi', type=str, default=None, help="The doi of the paper you would like to create a dataloader for.") +def create_dataloader(path_loader, doi, path_data) -> None: """ Interactively create a new sfaira dataloader. """ - dataloader_creator = DataloaderCreator() - dataloader_creator.create_dataloader() + if doi is None or re.match(r'\b10\.\d+/[\w.]+\b', doi): + dataloader_creator = DataloaderCreator(path_loader, doi) + dataloader_creator.create_dataloader() + dataloader_creator.create_datadir(path_data) + else: + print('[bold red]The supplied DOI is malformed!') # noqa: W605 @sfaira_cli.command() -@click.argument('path', type=click.Path(exists=True)) -def clean_dataloader(path) -> None: - """ - Clean a just written sfaira dataloader to adhere to sfaira's standards. - - PATH to the dataloader script. 
- """ - dataloader_cleaner = DataloaderCleaner(path) - dataloader_cleaner.clean_dataloader() - - -@sfaira_cli.command() -@click.argument('path', type=click.Path(exists=True)) -def validate_dataloader(path) -> None: +@click.option('--path-loader', + default="sfaira/data/dataloaders/loaders/", + type=click.Path(exists=True), + help='Relative path from the current directory to the desired location of the dataloader.' + ) +@click.option('--doi', type=str, default=None, help="The doi of the paper that the dataloader refers to.") +def validate_dataloader(path_loader, doi) -> None: """ Verifies the dataloader against sfaira's requirements. PATH to the dataloader script. """ - dataloader_validator = DataloaderValidator(path) - dataloader_validator.validate() + if doi is None or re.match(r'\b10\.\d+/[\w.]+\b', doi): + dataloader_validator = DataloaderValidator(path_loader, doi) + dataloader_validator.validate() + else: + print('[bold red]The supplied DOI is malformed!') # noqa: W605 @sfaira_cli.command() -@click.argument('path', type=click.Path(exists=True)) -@click.option('--doi', type=str, default=None) -@click.option('--test-data', type=click.Path(exists=True)) -def annotate_dataloader(path, doi, test_data) -> None: +@click.option('--path-loader', + default="sfaira/data/dataloaders/loaders/", + type=click.Path(exists=True), + help='Relative path from the current directory to the location of the dataloader.' + ) +@click.option('--path-data', + default="sfaira/unit_tests/template_data/", + type=click.Path(exists=True), + help='Relative path from the current directory to the datafiles used by this dataloader.' + ) +@click.option('--doi', type=str, default=None, help="The doi of the paper that the dataloader refers to.") +def annotate_dataloader(path_loader, path_data, doi) -> None: """ Annotates a dataloader. PATH is the absolute path of the root of your sfaira clone. """ - dataloader_annotater = DataloaderAnnotater() - dataloader_annotater.annotate(path, doi, test_data) + if doi is None or re.match(r'\b10\.\d+/[\w.]+\b', doi): + dataloader_validator = DataloaderValidator(path_loader, doi) + dataloader_validator.validate() + dataloader_annotater = DataloaderAnnotater() + dataloader_annotater.annotate(path_loader, path_data, dataloader_validator.doi) + else: + print('[bold red]The supplied DOI is malformed!') # noqa: W605 @sfaira_cli.command() -@click.argument('path', type=click.Path(exists=True)) -@click.option('--test-data', type=click.Path(exists=True)) -@click.option('--doi', type=str, default=None) -def test_dataloader(path, test_data, doi) -> None: +@click.option('--path-loader', + default="sfaira/data/dataloaders/loaders/", + type=click.Path(exists=True), + help='Relative path from the current directory to the location of the dataloader.' + ) +@click.option('--path-data', + default="sfaira/unit_tests/template_data/", + type=click.Path(exists=True), + help='Relative path from the current directory to the datafiles used by this dataloader.' + ) +@click.option('--doi', type=str, default=None, help="The doi of the paper that the dataloader refers to.") +def test_dataloader(path_loader, path_data, doi) -> None: """Runs a dataloader integration test. PATH is the absolute path of the root of your sfaira clone. 
""" - dataloader_tester = DataloaderTester(path, test_data, doi) - dataloader_tester.test_dataloader() + if doi is None or re.match(r'\b10\.\d+/[\w.]+\b', doi): + dataloader_tester = DataloaderTester(path_loader, path_data, doi) + dataloader_tester.test_dataloader() + else: + print('[bold red]The supplied DOI is malformed!') # noqa: W605 if __name__ == "__main__": diff --git a/sfaira/commands/annotate_dataloader.py b/sfaira/commands/annotate_dataloader.py index 706339707..a8dd11cd0 100644 --- a/sfaira/commands/annotate_dataloader.py +++ b/sfaira/commands/annotate_dataloader.py @@ -1,9 +1,13 @@ import os import pydoc import shutil +import re +from typing import Union from sfaira.data import DatasetGroupDirectoryOriented, DatasetGroup, DatasetBase from sfaira.data.utils import read_yaml +from sfaira.consts.utils import clean_doi +from sfaira.commands.questionary import sfaira_questionary try: import sfaira_extension as sfairae @@ -23,7 +27,7 @@ def __init__(self): self.dir_loader_sfairae = None self.package_source = None - def annotate(self, path: str, doi: str, test_data: str): + def annotate(self, path_loader: str, path_data: str, doi: Union[str, None]): """ Annotates a provided dataloader. @@ -35,9 +39,18 @@ def annotate(self, path: str, doi: str, test_data: str): (Note that columns are separated by ",") You can also manually check maps here: https://www.ebi.ac.uk/ols/ontologies/cl """ - doi_sfaira_repr = f'd{doi.translate({ord(c): "_" for c in r"!@#$%^&*()[]/{};:,.<>?|`~-=_+"})}' + if not doi: + doi = sfaira_questionary(function='text', + question='DOI:', + default='10.1000/j.journal.2021.01.001') + while not re.match(r'\b10\.\d+/[\w.]+\b', doi): + print('[bold red]The entered DOI is malformed!') # noqa: W605 + doi = sfaira_questionary(function='text', + question='DOI:', + default='10.1000/j.journal.2021.01.001') + doi_sfaira_repr = clean_doi(doi) self._setup_loader(doi_sfaira_repr) - self._annotate(test_data, path, doi) + self._annotate(path_data, path_loader, doi, doi_sfaira_repr) def _setup_loader(self, doi_sfaira_repr: str): """ @@ -70,7 +83,7 @@ def _setup_loader(self, doi_sfaira_repr: str): self.meta_path = meta_path self.cache_path = cache_path self.dir_loader = dir_loader - self.dir_loader_sfairae = dir_loader_sfairae + self.dir_loader_sfairae = None if sfairae is None else dir_loader_sfairae self.package_source = package_source def _get_ds(self, test_data: str): @@ -83,27 +96,33 @@ def _get_ds(self, test_data: str): return ds - def buffered_load(self, test_data: str): + def buffered_load(self, test_data: str, doi_sfaira_repr: str): + if not os.path.exists(test_data): + raise ValueError(f"test-data directory {test_data} does not exist.") + if doi_sfaira_repr not in os.listdir(test_data): + raise ValueError(f"did not find data folder named {doi_sfaira_repr} in test-data directory " + f"{test_data}, only found {os.listdir(test_data)}") ds = self._get_ds(test_data=test_data) - # TODO try-except with good error description saying that the data loader is broken here: ds.load( remove_gene_version=False, match_to_reference=None, load_raw=True, # Force raw load so non confound future tests by data loader bugs in previous versions. 
- allow_caching=True, + allow_caching=False, + verbose=3 ) - - assert len(ds.ids) > 0, f"no data sets loaded, make sure raw data is in {test_data}" + assert len(ds.ids) > 0, f"no data sets loaded, make sure raw data is in {test_data}, "\ + f"found {os.listdir(os.path.join(test_data, doi_sfaira_repr))}" return ds - def _annotate(self, test_data: str, path: str, doi: str): - ds = self.buffered_load(test_data=test_data) + def _annotate(self, test_data: str, path: str, doi: str, doi_sfaira_repr: str): + ds = self.buffered_load(test_data=test_data, doi_sfaira_repr=doi_sfaira_repr) # Create cell type conversion table: cwd = os.path.dirname(self.file_path) dataset_module = str(cwd.split("/")[-1]) # Group data sets by file module: # Note that if we were not grouping the cell type map .tsv files by file module, we could directly call # write_ontology_class_map on the ds. + tsvs_written = [] for f in os.listdir(cwd): if os.path.isfile(os.path.join(cwd, f)): # only files # Narrow down to data set files: @@ -172,8 +191,11 @@ def _annotate(self, test_data: str, path: str, doi: str): # III) Write this directly into the sfaira clone so that it can be committed via git. # TODO any errors not to be caught here? doi_sfaira_repr = f'd{doi.translate({ord(c): "_" for c in r"!@#$%^&*()[]/{};:,.<>?|`~-=_+"})}' + fn_tsv = os.path.join(path, doi_sfaira_repr, f"{file_module}.tsv") dsg_f.write_ontology_class_map( - fn=os.path.join(f"{path}/sfaira/data/dataloaders/loaders/{doi_sfaira_repr}/{file_module}.tsv"), + fn=fn_tsv, protected_writing=True, n_suggest=4, ) + tsvs_written.append(fn_tsv) + print(f"Completed annotation. Wrote {len(tsvs_written)} files:\n" + "\n".join(tsvs_written)) diff --git a/sfaira/commands/clean_dataloader.py b/sfaira/commands/clean_dataloader.py deleted file mode 100644 index 1823c203a..000000000 --- a/sfaira/commands/clean_dataloader.py +++ /dev/null @@ -1,24 +0,0 @@ -import logging - -import yaml -from boltons.iterutils import remap - -log = logging.getLogger(__name__) - - -class DataloaderCleaner: - - def __init__(self, path): - self.path = path - - def clean_dataloader(self) -> None: - """ - Removes unused keys from the yaml file - """ - with open(self.path) as yaml_file: - content = yaml.load(yaml_file, Loader=yaml.FullLoader) - drop_falsey = lambda path, key, value: bool(value) - clean = remap(content, visit=drop_falsey) - - with open(self.path, 'w') as file: - yaml.dump(clean, file) diff --git a/sfaira/commands/create_dataloader.py b/sfaira/commands/create_dataloader.py index ef674ca84..97450ffa6 100644 --- a/sfaira/commands/create_dataloader.py +++ b/sfaira/commands/create_dataloader.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, asdict from typing import Union, Dict +from sfaira.consts.utils import clean_doi, clean_id_str from sfaira.commands.questionary import sfaira_questionary from rich import print from cookiecutter.main import cookiecutter @@ -43,10 +44,12 @@ class TemplateAttributes: class DataloaderCreator: - def __init__(self): + def __init__(self, path_loader, doi): self.WD = os.path.dirname(__file__) self.TEMPLATES_PATH = f'{self.WD}/templates' self.template_attributes = TemplateAttributes() + self.out_path = path_loader + self.doi = doi def create_dataloader(self): """ @@ -80,16 +83,19 @@ def _prompt_dataloader_configuration(self): question='Author(s):', default='Einstein, Albert; Hawking, Stephen') self.template_attributes.author = author.split(';') if ';' in author else author - doi = sfaira_questionary(function='text', - question='DOI:', - 
default='10.1000/j.journal.2021.01.001') - while not re.match(r'\b10\.\d+/[\w.]+\b', doi): - print('[bold red]The entered DOI is malformed!') + if self.doi: + doi = self.doi + else: doi = sfaira_questionary(function='text', question='DOI:', default='10.1000/j.journal.2021.01.001') + while not re.match(r'\b10\.\d+/[\w.]+\b', doi): + print('[bold red]The entered DOI is malformed!') + doi = sfaira_questionary(function='text', + question='DOI:', + default='10.1000/j.journal.2021.01.001') self.template_attributes.doi = doi - self.template_attributes.doi_sfaira_repr = f'd{doi.translate({ord(c): "_" for c in r"!@#$%^&*()[]/{};:,.<>?|`~-=_+"})}' + self.template_attributes.doi_sfaira_repr = clean_doi(doi) self.template_attributes.number_of_datasets = sfaira_questionary(function='text', question='Number of datasets:', @@ -128,7 +134,7 @@ def _prompt_dataloader_configuration(self): default='raw') self.template_attributes.disease = sfaira_questionary(function='text', question='Disease:', - default='NA') + default='healthy') self.template_attributes.state_exact = sfaira_questionary(function='text', question='Sample state:', default='healthy') @@ -151,10 +157,13 @@ def _prompt_dataloader_configuration(self): except KeyError: print('[bold yellow] First author was not in the expected format. Using full first author for the id.') first_author_lastname = first_author - self.template_attributes.id_without_doi = f'{self.template_attributes.organism}_{self.template_attributes.organ}_' \ - f'{self.template_attributes.year}_{self.template_attributes.assay_sc}_' \ - f'{first_author_lastname}_001' - self.template_attributes.id = self.template_attributes.id_without_doi + f'_{self.template_attributes.doi_sfaira_repr}' + self.template_attributes.id_without_doi = f'{clean_id_str(self.template_attributes.organism)}_' \ + f'{clean_id_str(self.template_attributes.organ)}_' \ + f'{clean_id_str(self.template_attributes.year)}_' \ + f'{clean_id_str(self.template_attributes.assay_sc)}_' \ + f'{clean_id_str(first_author_lastname)}_001' + self.template_attributes.id = f'{self.template_attributes.id_without_doi}_' \ + f'{self.template_attributes.doi_sfaira_repr}' if self.template_attributes.dataloader_type == 'single_dataset': self.template_attributes.download_url_data = sfaira_questionary(function='text', question='URL to download the data', @@ -180,6 +189,10 @@ def _template_attributes_to_dict(self) -> dict: def _create_dataloader_template(self): template_path = f'{self.TEMPLATES_PATH}/{self.template_attributes.dataloader_type}' cookiecutter(f'{template_path}', + output_dir=self.out_path, no_input=True, overwrite_if_exists=True, extra_context=self._template_attributes_to_dict()) + + def create_datadir(self, path_data): + os.makedirs(os.path.join(path_data, self.template_attributes.doi_sfaira_repr)) diff --git a/sfaira/commands/templates/multiple_datasets/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml b/sfaira/commands/templates/multiple_datasets/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml index 24285b789..caf9c4ebb 100644 --- a/sfaira/commands/templates/multiple_datasets/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml +++ b/sfaira/commands/templates/multiple_datasets/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml @@ -66,6 +66,5 @@ observation_wise: feature_wise: gene_id_ensembl_var_key: gene_id_symbols_var_key: -misc: meta: version: "1.0" diff --git a/sfaira/commands/templates/single_dataset/{{ 
cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml b/sfaira/commands/templates/single_dataset/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml index 258606767..60f5f5fb3 100644 --- a/sfaira/commands/templates/single_dataset/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml +++ b/sfaira/commands/templates/single_dataset/{{ cookiecutter.doi_sfaira_repr }}/{{ cookiecutter.id_without_doi }}.yaml @@ -48,6 +48,5 @@ observation_wise: feature_wise: gene_id_ensembl_var_key: gene_id_symbols_var_key: -misc: meta: version: "1.0" diff --git a/sfaira/commands/test_dataloader.py b/sfaira/commands/test_dataloader.py index 16f9f0a58..f2f4303cb 100644 --- a/sfaira/commands/test_dataloader.py +++ b/sfaira/commands/test_dataloader.py @@ -1,9 +1,17 @@ import logging import os -from subprocess import Popen +import shutil +import pydoc from rich import print from sfaira.commands.questionary import sfaira_questionary +from sfaira.consts.utils import clean_doi +from sfaira.data import DatasetGroupDirectoryOriented + +try: + import sfaira_extension as sfairae +except ImportError: + sfairae = None log = logging.getLogger(__name__) @@ -22,33 +30,70 @@ def test_dataloader(self): """ Runs a predefined unit test on a given dataloader. """ - print('[bold blue]Please ensure that your dataloader is in sfaira/dataloaders/loaders/.') if not self.doi: self._prompt_doi() - self.doi_sfaira_repr = f'd{self.doi.translate({ord(c): "_" for c in r"!@#$%^&*()[]/{};:,.<>?|`~-=_+"})}' - self._run_unittest() + self.doi_sfaira_repr = clean_doi(self.doi) + print(f'[bold blue]Please ensure that your dataloader is in sfaira/dataloaders/loaders/{self.doi_sfaira_repr}.') + self._test_dataloader() def _prompt_doi(self): self.doi = sfaira_questionary(function='text', question='Enter your DOI', default='10.1000/j.journal.2021.01.001') - def _run_unittest(self): + def _get_ds(self): + dir_loader_sfaira = "sfaira.data.dataloaders.loaders." + file_path_sfaira = os.path.dirname(str(pydoc.locate(dir_loader_sfaira + "FILE_PATH"))) + + dir_loader_sfairae = "sfaira_extension.data.dataloaders.loaders." if sfairae else None + file_path_sfairae = os.path.dirname(str(pydoc.locate(dir_loader_sfairae + "FILE_PATH"))) if sfairae else None + + # Check if loader name is a directory either in sfaira or sfaira_extension loader collections: + if self.doi_sfaira_repr in os.listdir(file_path_sfaira): + dir_loader = dir_loader_sfaira + "." + self.doi_sfaira_repr + elif file_path_sfairae and self.doi_sfaira_repr in os.listdir(file_path_sfairae): + dir_loader = dir_loader_sfairae + "." + self.doi_sfaira_repr + else: + raise ValueError("data loader not found in sfaira and also not in sfaira_extension") + file_path = str(pydoc.locate(dir_loader + ".FILE_PATH")) + cache_path = None + # Clear dataset cache + shutil.rmtree(cache_path, ignore_errors=True) + + ds = DatasetGroupDirectoryOriented( + file_base=file_path, + data_path=self.test_data, + meta_path=None, + cache_path=None + ) + + return ds, cache_path + + def _test_dataloader(self): """ - Runs the actual integration test by invoking pytest on it. + Tests the dataloader. 
""" print('[bold blue]Conflicts are not automatically resolved.') - print('[bold blue]Please go back to [bold]https://www.ebi.ac.uk/ols/ontologies/cl[blue] for every mismatch or conflicts ' - 'and add the correct cell ontology class name into the .tsv "target" column.') - - os.chdir(f'{self.path}/sfaira/unit_tests/data_contribution') + print('[bold blue]Please go back to [bold]https://www.ebi.ac.uk/ols/ontologies/cl[blue] for every mismatch or ' + 'conflicts and add the correct cell ontology class name into the .tsv "target" column.') - pytest = Popen(['pytest', 'test_data_template.py', '--doi_sfaira_repr', self.doi_sfaira_repr, '--test_data', self.test_data], - universal_newlines=True, shell=False, close_fds=True) - (pytest_stdout, pytest_stderr) = pytest.communicate() - if pytest_stdout: - print(pytest_stdout) - if pytest_stderr: - print(pytest_stderr) + ds, cache_path = self._get_ds() + ds.clean_ontology_class_map() - os.chdir(self.cwd) + # TODO try-except with good error description saying that the data loader is broken here: + ds.load( + remove_gene_version=True, + # match_to_reference=TODO get organism here, + load_raw=True, + allow_caching=True + ) + # Try loading from cache: + ds, cache_path = self._get_ds() + # TODO try-except with good error description saying that the data loader is broken here: + ds.load( + remove_gene_version=True, + # match_to_reference=TODO get organism here, + load_raw=False, + allow_caching=True + ) + shutil.rmtree(cache_path, ignore_errors=True) diff --git a/sfaira/commands/validate_dataloader.py b/sfaira/commands/validate_dataloader.py index 19371026b..77c053262 100644 --- a/sfaira/commands/validate_dataloader.py +++ b/sfaira/commands/validate_dataloader.py @@ -1,4 +1,6 @@ import logging +import os +import re import rich import yaml @@ -7,13 +9,28 @@ from flatten_dict.reducer import make_reducer from rich.progress import Progress, BarColumn +from sfaira.consts.utils import clean_doi +from sfaira.commands.questionary import sfaira_questionary + log = logging.getLogger(__name__) class DataloaderValidator: - def __init__(self, path='.'): - self.path: str = path + def __init__(self, path_loader, doi): + if not doi: + doi = sfaira_questionary(function='text', + question='DOI:', + default='10.1000/j.journal.2021.01.001') + while not re.match(r'\b10\.\d+/[\w.]+\b', doi): + print('[bold red]The entered DOI is malformed!') # noqa: W605 + doi = sfaira_questionary(function='text', + question='DOI:', + default='10.1000/j.journal.2021.01.001') + self.doi = doi + + loader_filename = [i for i in os.listdir(os.path.join(path_loader, clean_doi(doi))) if str(i).endswith(".yaml")][0] + self.path_loader: str = os.path.join(path_loader, clean_doi(doi), loader_filename) self.content: dict = {} self.passed: dict = {} self.warned: dict = {} @@ -27,7 +44,7 @@ def validate(self) -> None: Statically verifies a yaml dataloader file against a predefined set of rules. Every rule is a function defined in this class, which must be part of this class' linting_functions. 
""" - with open(self.path) as yaml_file: + with open(self.path_loader) as yaml_file: self.content = yaml.load(yaml_file, Loader=yaml.FullLoader) progress = Progress("[bold green]{task.description}", BarColumn(bar_width=None), diff --git a/sfaira/consts/utils.py b/sfaira/consts/utils.py new file mode 100644 index 000000000..2826c8144 --- /dev/null +++ b/sfaira/consts/utils.py @@ -0,0 +1,11 @@ +import os + + +def clean_doi(doi: str): + return f'd{doi.translate({ord(c): "_" for c in r"!@#$%^&*()[]/{};:,.<>?|`~-=_+"})}' + + +def clean_id_str(s): + if s is not None: + s = s.replace(',', '').replace(' ', '').replace('-', '').replace('_', '').replace("'", '').lower() + return s diff --git a/sfaira/data/__init__.py b/sfaira/data/__init__.py index 11e56506e..5f692fb65 100644 --- a/sfaira/data/__init__.py +++ b/sfaira/data/__init__.py @@ -1,5 +1,4 @@ -from sfaira.data.base import clean_string, DatasetBase, \ - DatasetGroup, DatasetGroupDirectoryOriented, \ +from sfaira.data.base import DatasetBase, DatasetGroup, DatasetGroupDirectoryOriented, \ DatasetSuperGroup, load_store, DistributedStoreBase, DistributedStoreH5ad, DistributedStoreDao from . import dataloaders from .dataloaders import Universe diff --git a/sfaira/data/base/__init__.py b/sfaira/data/base/__init__.py index e5c92edcd..0f9e9339a 100644 --- a/sfaira/data/base/__init__.py +++ b/sfaira/data/base/__init__.py @@ -1,4 +1,4 @@ -from sfaira.data.base.dataset import DatasetBase, clean_string +from sfaira.data.base.dataset import DatasetBase from sfaira.data.base.dataset_group import DatasetGroup, DatasetGroupDirectoryOriented, DatasetSuperGroup from sfaira.data.base.distributed_store import load_store, DistributedStoreBase, DistributedStoreH5ad, \ DistributedStoreDao diff --git a/sfaira/data/base/dataset.py b/sfaira/data/base/dataset.py index 9801e8470..94a625af7 100644 --- a/sfaira/data/base/dataset.py +++ b/sfaira/data/base/dataset.py @@ -23,6 +23,7 @@ from sfaira.consts import AdataIds, AdataIdsCellxgene, AdataIdsSfaira, META_DATA_FIELDS, OCS from sfaira.data.base.io_dao import write_dao from sfaira.data.utils import collapse_matrix, read_yaml +from sfaira.consts.utils import clean_id_str UNS_STRING_META_IN_OBS = "__obs__" @@ -65,12 +66,6 @@ def is_child( raise ValueError(f"did not recognize ontology type {type(ontology)}") -def clean_string(s): - if s is not None: - s = s.replace(',', '').replace(' ', '').replace('-', '').replace('_', '').replace("'", '').lower() - return s - - def get_directory_formatted_doi(x: str) -> str: return "d" + "_".join("_".join("_".join(x.split("/")).split(".")).split("-")) @@ -1434,11 +1429,11 @@ def set_dataset_id( # Note: access private attributes here, e.g. _organism, to avoid loading of content via meta data, which would # invoke call to self.id before it is set. 
- self.id = f"{clean_string(self._organism)}_" \ - f"{clean_string(self._organ)}_" \ + self.id = f"{clean_id_str(self._organism)}_" \ + f"{clean_id_str(self._organ)}_" \ f"{self._year}_" \ - f"{clean_string(self._assay_sc)}_" \ - f"{clean_string(author)}_" \ + f"{clean_id_str(self._assay_sc)}_" \ + f"{clean_id_str(author)}_" \ f"{idx}_" \ f"{self.doi_main}" diff --git a/sfaira/data/base/dataset_group.py b/sfaira/data/base/dataset_group.py index 1c1675364..33b37f85b 100644 --- a/sfaira/data/base/dataset_group.py +++ b/sfaira/data/base/dataset_group.py @@ -82,6 +82,7 @@ def load( processes: int = 1, func=None, kwargs_func: Union[None, dict] = None, + verbose: int = 0, **kwargs ): """ @@ -103,6 +104,12 @@ def func(dataset, **kwargs_func): # code manipulating dataset and generating output x. return x :param kwargs_func: Kwargs of func. + :param verbose: Verbosity of description of loading failure. + + - 0: no indication of failure + - 1: indication of which data set failed in warning + - 2: 1 with error report in warning + - 3: reportin as in 2 but aborts with OSError """ args = [ load_raw, @@ -132,7 +139,12 @@ def func(dataset, **kwargs_func): x = map_fn(tuple([v] + args)) # Clear data sets that were not successfully loaded because of missing data: if x is not None: - warnings.warn(f"data set {k} not loaded") + if verbose == 1: + warnings.warn(f"data set {k} not loaded") + if verbose == 2: + warnings.warn(f"data set {k} not loaded\nin data set {x[0]}: {x[1]}") + if verbose == 3: + raise OSError(f"data set {k} not loaded\nin data set {x[0]}: {x[1]}") datasets_to_remove.append(k) for k in datasets_to_remove: del self.datasets[k] diff --git a/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml b/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml index c8860ba9c..4edce9df5 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml @@ -7,8 +7,8 @@ dataset_wise: default_embedding: doi_journal: "10.1016/j.cell.2021.04.048" doi_preprint: "10.1101/2020.10.12.335331" - download_url_data: "https://atlas.fredhutch.org/nygc/multimodal-pbmc/" - download_url_meta: "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE164378&format=file&file=GSE164378%5Fsc%2Emeta%2Edata%5F3P%2Ecsv%2Egz" + download_url_data: "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE164nnn/GSE164378/suppl/GSE164378_RAW.tar" + download_url_meta: "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE164nnn/GSE164378/suppl/GSE164378_sc.meta.data_3P.csv.gz" normalization: "raw" primary_data: year: 2020 diff --git a/sfaira/data/dataloaders/loaders/super_group.py b/sfaira/data/dataloaders/loaders/super_group.py index 456787694..bbec57dfc 100644 --- a/sfaira/data/dataloaders/loaders/super_group.py +++ b/sfaira/data/dataloaders/loaders/super_group.py @@ -33,7 +33,7 @@ def __init__( for f in os.listdir(cwd): if os.path.isdir(os.path.join(cwd, f)): # only directories if f[:len(dir_prefix)] == dir_prefix and f not in dir_exclude: # Narrow down to data set directories - path_dsg = pydoc.locate(f"sfaira.data.dataloaders.loaders.{f}.FILE_PATH") + path_dsg = str(pydoc.locate(f"sfaira.data.dataloaders.loaders.{f}.FILE_PATH")) if path_dsg is not None: dataset_groups.append(DatasetGroupDirectoryOriented( file_base=path_dsg, diff --git a/sfaira/data/utils_scripts/streamline_selected.py b/sfaira/data/utils_scripts/streamline_selected.py index 
74cfba10e..27c2ef0ef 100644 --- a/sfaira/data/utils_scripts/streamline_selected.py +++ b/sfaira/data/utils_scripts/streamline_selected.py @@ -2,8 +2,6 @@ import sfaira import sys -from sfaira.data import clean_string - # Set global variables. print("sys.argv", sys.argv) diff --git a/sfaira/unit_tests/data_contribution/__init__.py b/sfaira/unit_tests/data_contribution/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/sfaira/unit_tests/data_contribution/test_data_template.py b/sfaira/unit_tests/data_contribution/test_data_template.py deleted file mode 100644 index e81c007ae..000000000 --- a/sfaira/unit_tests/data_contribution/test_data_template.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -import pydoc -import shutil - -from sfaira.data import DatasetGroupDirectoryOriented - -try: - import sfaira_extension as sfairae -except ImportError: - sfairae = None - - -def _get_ds(doi_sfaira_repr: str, test_data: str): - dir_loader_sfaira = "sfaira.data.dataloaders.loaders." - file_path_sfaira = "/" + "/".join(pydoc.locate(dir_loader_sfaira + "FILE_PATH").split("/")[:-1]) - if sfairae is not None: - dir_loader_sfairae = "sfaira_extension.data.dataloaders.loaders." - file_path_sfairae = "/" + "/".join(pydoc.locate(dir_loader_sfairae + "FILE_PATH").split("/")[:-1]) - else: - file_path_sfairae = None - # Check if loader name is a directory either in sfaira or sfaira_extension loader collections: - if doi_sfaira_repr in os.listdir(file_path_sfaira): - dir_loader = dir_loader_sfaira + "." + doi_sfaira_repr - elif doi_sfaira_repr in os.listdir(file_path_sfairae): - dir_loader = dir_loader_sfairae + "." + doi_sfaira_repr - else: - raise ValueError("data loader not found in sfaira and also not in sfaira_extension") - file_path = pydoc.locate(dir_loader + ".FILE_PATH") - cache_path = None - # Clear dataset cache - shutil.rmtree(cache_path, ignore_errors=True) - - ds = DatasetGroupDirectoryOriented( - file_base=file_path, - data_path=test_data, - meta_path=None, - cache_path=None - ) - - return ds, cache_path - - -def test_load(doi_sfaira_repr: str, test_data: str): - ds, cache_path = _get_ds(doi_sfaira_repr=doi_sfaira_repr, test_data=test_data) - - ds.clean_ontology_class_map() - - # TODO try-except with good error description saying that the data loader is broken here: - ds.load( - remove_gene_version=True, - # match_to_reference=TODO get organism here, - load_raw=True, - allow_caching=True - ) - # Try loading from cache: - ds = _get_ds(doi_sfaira_repr=doi_sfaira_repr, test_data=test_data) - # TODO try-except with good error description saying that the data loader is broken here: - ds.load( - remove_gene_version=True, - # match_to_reference=TODO get organism here, - load_raw=False, - allow_caching=True - ) - shutil.rmtree(cache_path, ignore_errors=True) From 818ca8cbf519e11e2a35bf5facbeca41968e4fcb Mon Sep 17 00:00:00 2001 From: "David S. 
Fischer" Date: Fri, 6 Aug 2021 18:51:36 +0200 Subject: [PATCH 08/15] Feature/dao improvements (#318) * updated rounding in cellxgene format export warning * updated DOIs to distinguish preprint and journal * fixed issue with ethnicity handling in cellxgene export * reordered obs in cellxgene streamlining * added store benchmark script * added multi-organism store * update doi setting in datasetinteractive * added mock data for unit test * added msle metric * enabled in memory handling of h5ad backed store * added infrastructure for ontology re-caching * fixed all unit tests and optimised run time a bit Co-authored-by: Abdul Moeed Co-authored-by: le-ander <20015434+le-ander@users.noreply.github.com> --- .gitignore | 7 +- docs/distributed_data.rst | 39 + docs/index.rst | 1 + requirements.txt | 2 + sfaira/cli.py | 38 + sfaira/commands/validate_h5ad.py | 40 + sfaira/consts/__init__.py | 2 + sfaira/consts/adata_fields.py | 19 +- sfaira/consts/directories.py | 14 + sfaira/consts/ontologies.py | 18 +- sfaira/consts/utils.py | 30 + sfaira/data/__init__.py | 7 +- sfaira/data/base/__init__.py | 4 - sfaira/data/dataloaders/__init__.py | 1 + sfaira/data/dataloaders/base/__init__.py | 3 + sfaira/data/{ => dataloaders}/base/dataset.py | 137 ++- .../{ => dataloaders}/base/dataset_group.py | 35 +- sfaira/data/dataloaders/base/utils.py | 43 + .../databases/cellxgene/__init__.py | 2 +- .../databases/cellxgene/cellxgene_group.py | 104 ++- .../databases/cellxgene/cellxgene_loader.py | 244 +++++- .../databases/cellxgene/rest_helpers.py | 66 ++ .../data/dataloaders/databases/super_group.py | 20 +- ...fcolon_2019_10xsequencing_kinchen_001.yaml | 2 +- ...pithelium_2019_10xsequencing_smilie_001.py | 3 +- ...man_ileum_2019_10xsequencing_martin_001.py | 2 +- ...stategland_2018_10xsequencing_henry_001.py | 2 +- ...uman_lung_2020_10xsequencing_miller_001.py | 2 +- ...human_testis_2018_10xsequencing_guo_001.py | 2 +- ...liver_2018_10xsequencing_macparland_001.py | 2 +- .../human_x_2019_10xsequencing_szabo_001.py | 2 +- ...man_retina_2019_10xsequencing_menon_001.py | 2 +- .../human_placenta_2018_x_ventotormo_001.py | 2 +- ...ver_2019_10xsequencing_ramachandran_001.py | 2 +- ...an_liver_2019_10xsequencing_popescu_001.py | 2 +- .../human_lung_2020_x_travaglini_001.yaml | 2 +- ...uman_colon_2020_10xsequencing_james_001.py | 12 +- .../human_x_2019_10xsequencing_braga_x.py | 2 +- .../mouse_x_2019_10xsequencing_hove_001.py | 2 +- ...uman_kidney_2020_10xsequencing_liao_001.py | 2 +- ...man_retina_2019_10xsequencing_voigt_001.py | 2 +- .../human_x_2019_10xsequencing_wang_001.py | 2 +- ...an_lung_2020_10xsequencing_lukassen_001.py | 2 +- .../human_blood_2020_10x_hao_001.yaml | 2 +- .../d10_1101_661728/mouse_x_2019_x_pisco_x.py | 2 +- ...nchyma_2020_10xsequencing_habermann_001.py | 23 +- ...n_kidney_2019_10xsequencing_stewart_001.py | 7 +- ...uman_thymus_2020_10xsequencing_park_001.py | 11 + ...uman_x_2019_10xsequencing_madissoon_001.py | 2 +- ..._retina_2019_10xsequencing_lukowski_001.py | 2 +- ...lood_2019_10xsequencing_10xgenomics_001.py | 2 +- .../human_x_2018_10xsequencing_regev_001.py | 7 +- .../data/dataloaders/loaders/super_group.py | 2 +- sfaira/data/dataloaders/super_group.py | 17 +- sfaira/data/interactive/loader.py | 2 +- sfaira/data/store/__init__.py | 3 + sfaira/data/{base => store}/io_dao.py | 18 +- sfaira/data/store/multi_store.py | 338 ++++++++ .../single_store.py} | 779 ++++++++++-------- .../data/utils_scripts/streamline_selected.py | 1 - sfaira/data/utils_scripts/test_store.py | 284 +++++++ 
sfaira/data/utils_scripts/test_streamlined.sh | 23 + sfaira/data/utils_scripts/write_store.py | 3 +- sfaira/estimators/keras.py | 71 +- sfaira/estimators/metrics.py | 9 + sfaira/train/summaries.py | 34 +- sfaira/train/train_model.py | 25 +- sfaira/ui/model_zoo.py | 19 +- sfaira/ui/user_interface.py | 13 +- sfaira/unit_tests/__init__.py | 1 + .../data/test_clean_celltype_maps.py | 13 - .../{data => data_for_tests}/__init__.py | 0 .../data_for_tests/databases/__init__.py | 1 + .../data_for_tests/databases/consts.py | 2 + .../data_for_tests/databases/utils.py | 28 + .../data_for_tests/loaders/__init__.py | 3 + .../data_for_tests/loaders/consts.py | 5 + .../loaders/loaders/__init__.py | 1 + .../loaders/loaders/dno_doi_mock1/__init__.py | 1 + ...human_lung_2021_10xtechnology_mock1_001.py | 12 + ...uman_lung_2021_10xtechnology_mock1_001.tsv | 3 + ...man_lung_2021_10xtechnology_mock1_001.yaml | 52 ++ .../loaders/loaders/dno_doi_mock2/__init__.py | 1 + ...e_pancreas_2021_10xtechnology_mock2_001.py | 12 + ..._pancreas_2021_10xtechnology_mock2_001.tsv | 4 + ...pancreas_2021_10xtechnology_mock2_001.yaml | 52 ++ .../loaders/loaders/dno_doi_mock3/__init__.py | 1 + ...human_lung_2021_10xtechnology_mock3_001.py | 12 + ...uman_lung_2021_10xtechnology_mock3_001.tsv | 3 + ...man_lung_2021_10xtechnology_mock3_001.yaml | 52 ++ .../loaders/loaders/super_group.py | 60 ++ .../data_for_tests/loaders/utils.py | 84 ++ sfaira/unit_tests/directories.py | 14 + .../test_data/model_lookuptable.csv | 3 - .../__init__.py | 0 .../data}/__init__.py | 0 .../data/test_clean_celltype_maps.py | 8 + .../data/test_data_utils.py | 0 .../tests_by_submodule/data/test_databases.py | 61 ++ .../data/test_dataset.py | 58 +- .../data/test_store.py | 110 ++- .../tests_by_submodule/estimators/__init__.py | 1 + .../estimators/custom.obo | 0 .../estimators/test_estimator.py | 209 ++--- .../trainer}/__init__.py | 0 .../trainer/test_trainer.py | 52 +- .../ui}/__init__.py | 0 .../ui/test_userinterface.py | 3 +- .../{ => tests_by_submodule}/ui/test_zoo.py | 4 - .../tests_by_submodule/versions/__init__.py | 0 .../versions/test_genomes.py | 67 ++ .../versions/test_ontologies.py | 29 +- .../versions/test_universe.py | 10 +- sfaira/unit_tests/utils.py | 96 --- sfaira/unit_tests/versions/test_genomes.py | 40 - sfaira/unit_tests/versions/test_zoo.py | 91 -- sfaira/versions/genomes/__init__.py | 2 + sfaira/versions/{ => genomes}/genomes.py | 111 ++- sfaira/versions/genomes/utils.py | 43 + sfaira/versions/metadata/base.py | 51 +- sfaira/versions/topologies/class_interface.py | 2 +- 121 files changed, 2921 insertions(+), 1173 deletions(-) create mode 100644 docs/distributed_data.rst create mode 100644 sfaira/commands/validate_h5ad.py create mode 100644 sfaira/consts/directories.py delete mode 100644 sfaira/data/base/__init__.py create mode 100644 sfaira/data/dataloaders/base/__init__.py rename sfaira/data/{ => dataloaders}/base/dataset.py (95%) rename sfaira/data/{ => dataloaders}/base/dataset_group.py (98%) create mode 100644 sfaira/data/dataloaders/base/utils.py create mode 100644 sfaira/data/dataloaders/databases/cellxgene/rest_helpers.py create mode 100644 sfaira/data/store/__init__.py rename sfaira/data/{base => store}/io_dao.py (94%) create mode 100644 sfaira/data/store/multi_store.py rename sfaira/data/{base/distributed_store.py => store/single_store.py} (53%) create mode 100644 sfaira/data/utils_scripts/test_store.py create mode 100644 sfaira/data/utils_scripts/test_streamlined.sh delete mode 100644 
sfaira/unit_tests/data/test_clean_celltype_maps.py rename sfaira/unit_tests/{data => data_for_tests}/__init__.py (100%) create mode 100644 sfaira/unit_tests/data_for_tests/databases/__init__.py create mode 100644 sfaira/unit_tests/data_for_tests/databases/consts.py create mode 100644 sfaira/unit_tests/data_for_tests/databases/utils.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/__init__.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/consts.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/__init__.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/__init__.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.tsv create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.yaml create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/__init__.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.tsv create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.yaml create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/__init__.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.tsv create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.yaml create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/super_group.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/utils.py create mode 100644 sfaira/unit_tests/directories.py delete mode 100644 sfaira/unit_tests/test_data/model_lookuptable.csv rename sfaira/unit_tests/{estimators => tests_by_submodule}/__init__.py (100%) rename sfaira/unit_tests/{trainer => tests_by_submodule/data}/__init__.py (100%) create mode 100644 sfaira/unit_tests/tests_by_submodule/data/test_clean_celltype_maps.py rename sfaira/unit_tests/{ => tests_by_submodule}/data/test_data_utils.py (100%) create mode 100644 sfaira/unit_tests/tests_by_submodule/data/test_databases.py rename sfaira/unit_tests/{ => tests_by_submodule}/data/test_dataset.py (63%) rename sfaira/unit_tests/{ => tests_by_submodule}/data/test_store.py (58%) create mode 100644 sfaira/unit_tests/tests_by_submodule/estimators/__init__.py rename sfaira/unit_tests/{ => tests_by_submodule}/estimators/custom.obo (100%) rename sfaira/unit_tests/{ => tests_by_submodule}/estimators/test_estimator.py (77%) rename sfaira/unit_tests/{ui => tests_by_submodule/trainer}/__init__.py (100%) rename sfaira/unit_tests/{ => tests_by_submodule}/trainer/test_trainer.py (56%) rename sfaira/unit_tests/{versions => tests_by_submodule/ui}/__init__.py (100%) rename sfaira/unit_tests/{ => tests_by_submodule}/ui/test_userinterface.py (89%) rename sfaira/unit_tests/{ => tests_by_submodule}/ui/test_zoo.py (82%) create mode 100644 sfaira/unit_tests/tests_by_submodule/versions/__init__.py create mode 100644 
sfaira/unit_tests/tests_by_submodule/versions/test_genomes.py rename sfaira/unit_tests/{ => tests_by_submodule}/versions/test_ontologies.py (79%) rename sfaira/unit_tests/{ => tests_by_submodule}/versions/test_universe.py (55%) delete mode 100644 sfaira/unit_tests/utils.py delete mode 100644 sfaira/unit_tests/versions/test_genomes.py delete mode 100644 sfaira/unit_tests/versions/test_zoo.py create mode 100644 sfaira/versions/genomes/__init__.py rename sfaira/versions/{ => genomes}/genomes.py (67%) create mode 100644 sfaira/versions/genomes/utils.py diff --git a/.gitignore b/.gitignore index 424f7eb03..55bf44b45 100644 --- a/.gitignore +++ b/.gitignore @@ -5,10 +5,9 @@ cache/ontologies/cl/* docs/api/ # Unit test temporary data: -sfaira/unit_tests/test_data_loaders/* -sfaira/unit_tests/test_data/* -sfaira/unit_tests/template_data/* -sfaira/unit_tests/mock_data/store_* +sfaira/unit_tests/data_for_testing/mock_data/store* +**cache +**temp # General patterns: git abuild diff --git a/docs/distributed_data.rst b/docs/distributed_data.rst new file mode 100644 index 000000000..02b0ca127 --- /dev/null +++ b/docs/distributed_data.rst @@ -0,0 +1,39 @@ +Distributed data +================ + +Sfaira supports usage of distributed data for model training and execution. +The tools are summarized under `sfaira.data.store`. +In contrast to using an instance of AnnData in memory, these tools make it possible to use data sets that are saved +in different files (because they come from different studies) flexibly and out-of-core, +which means without loading them into memory. +A general use case is the training of a model on a large set of data sets, subsetted by particular cell-wise meta +data, without creating a merged AnnData instance in memory first. + +Build a distributed data repository +----------------------------------- + +You can use the sfaira dataset API to write streamlined groups of adata instances to a particular disk location that +then is the store directory. +Some of the array backends used for loading stores, such as dask, can read arrays from cloud servers. +Therefore, these store directories can also be on cloud servers in some cases. + +Reading from a distributed data repository +------------------------------------------ + +The core use case is the consumption of data in batches from a python iterator (a "generator"). +In contrast to using the full data matrix, this allows for workflows that never require the full data matrix in memory. +These generators can, for example, be used directly in tensorflow or pytorch stochastic mini-batch learning pipelines. +The core interface is `sfaira.data.load_store()`, which can be used to initialise a store instance that exposes a +generator, for example. +An important concept in store reading is that the data sets are already streamlined on disk, which means, for example, +that they share the same feature space. + +Distributed access optimised (DAO) store +---------------------------------------- + +The DAO store format is an on-disk representation of single-cell data which is optimised for generator-based and +distributed access. +In brief, DAO stores optimize memory consumption and data batch access speed. +Right now, we are using zarr and parquet; this may change in the future, and we will continue to work on this format +under the project name "dao". +Note that data sets represented as DAO on disk can still be read into AnnData instances in memory if you wish!
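A minimal sketch of the generator-based store access described above, assuming a store directory has already been written (for example with the sfaira/data/utils_scripts/write_store.py utility); the argument names `cache_path` and `store_format` are assumptions and may differ between sfaira versions::

    from sfaira.data import load_store

    # Open a directory of streamlined data sets out-of-core; "dao" and "h5ad" are the two
    # store backends referred to in this patch (keyword names are assumed).
    store = load_store(cache_path="/path/to/store_dao", store_format="dao")
    # The store exposes the concatenated data sets without loading them into memory;
    # mini-batches can then be drawn from a python generator, e.g. in tensorflow or
    # pytorch training loops.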
diff --git a/docs/index.rst b/docs/index.rst index 38c4db65c..858632fec 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,6 +40,7 @@ Latest additions tutorials adding_datasets consuming_data + distributed_data models ecosystem roadmap diff --git a/requirements.txt b/requirements.txt index a410be433..a74f8efe9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,10 @@ anndata>=0.7.6 crossref_commons +cellxgene-schema dask docutils fuzzywuzzy +IPython loompy matplotlib networkx diff --git a/sfaira/cli.py b/sfaira/cli.py index b581d77b9..cf730b3e0 100644 --- a/sfaira/cli.py +++ b/sfaira/cli.py @@ -13,6 +13,7 @@ from sfaira.commands.test_dataloader import DataloaderTester from sfaira.commands.validate_dataloader import DataloaderValidator +from sfaira.commands.validate_h5ad import H5adValidator import sfaira from sfaira.commands.create_dataloader import DataloaderCreator @@ -166,6 +167,43 @@ def test_dataloader(path_loader, path_data, doi) -> None: print('[bold red]The supplied DOI is malformed!') # noqa: W605 +@sfaira_cli.command() +@click.argument('doi', type=str) +@click.argument('schema', type=str, default=None) +@click.argument('path_out', type=click.Path(exists=True)) +@click.argument('path_data', type=click.Path(exists=True)) +@click.option('--path_cache', type=click.Path(exists=True), default=None) +def export_h5ad(test_h5ad, schema) -> None: + """Creates a collection of streamlined h5ad object for a given DOI. + + doi is the doi(s) to select for export. You can enumerate multiple dois by suppling a string of dois separated by + a comma. + schema is the schema type ("cellxgene",) to use for streamlining. + path_out is the absolute path to save output into. The h5ad files will be in a folder named after the DOI. + path_data is the absolute path to raw data library, ie one folder above the DOI named folder that contains the raw + files necessary for the selected data loader(s). + path_cache is the optional absolute path to cached data library maintained by sfaira. Using such a cache speeds + up loading in sequential runs but is not necessary. + """ + h5ad_tester = H5adValidator(test_h5ad, schema) + h5ad_tester.test_schema() + h5ad_tester.test_numeric_data() + + +@sfaira_cli.command() +@click.argument('test-h5ad', type=click.Path(exists=True)) +@click.option('--schema', type=str, default=None) +def test_h5ad(test_h5ad, schema) -> None: + """Runs a component test on a streamlined h5ad object. + + test-h5ad is the absolute path of the .h5ad file to test. + schema is the schema type ("cellxgene",) to test. 
+ """ + h5ad_tester = H5adValidator(test_h5ad, schema) + h5ad_tester.test_schema() + h5ad_tester.test_numeric_data() + + if __name__ == "__main__": traceback.install() sys.exit(main()) # pragma: no cover diff --git a/sfaira/commands/validate_h5ad.py b/sfaira/commands/validate_h5ad.py new file mode 100644 index 000000000..4194abd60 --- /dev/null +++ b/sfaira/commands/validate_h5ad.py @@ -0,0 +1,40 @@ +import logging + +import anndata +import numpy as np +import scipy.sparse + +log = logging.getLogger(__name__) + + +class H5adValidator: + + def __init__(self, test_h5ad, schema=None): + self.fn_h5ad: str = test_h5ad + if schema not in ["cellxgene"]: + raise ValueError(f"Did not recognize schema {schema}") + self.schema = schema + self._adata = None + + @property + def adata(self): + if self.adata is None: + self._adata = anndata.read_h5ad(filename=self.fn_h5ad) + return self._adata + + def test_schema(self) -> None: + """Verify that object elements match schema definitions.""" + if self.schema == "cellxgene": + from cellxgene_schema import validate + validate.validate(h5ad_path=self.fn_h5ad, shallow=False) + else: + assert False + + def test_numeric_data(self) -> None: + """Verify that numeric matrices match schema definitions.""" + if isinstance(self.adata.X, scipy.sparse.spmatrix): + x = np.unique(np.asarray(self.adata.X.todense())) + else: + x = np.unique(np.asarray(self.adata.X)) + deviation_from_integer = np.minimum(x % 1, 1. - x % 1) + assert np.max(deviation_from_integer) < 1e-6 diff --git a/sfaira/consts/__init__.py b/sfaira/consts/__init__.py index c48140cbe..aaaa7d3e1 100644 --- a/sfaira/consts/__init__.py +++ b/sfaira/consts/__init__.py @@ -1,5 +1,7 @@ from sfaira.consts.adata_fields import AdataIds, AdataIdsSfaira, AdataIdsCellxgene +from sfaira.consts.directories import CACHE_DIR from sfaira.consts.meta_data_files import META_DATA_FIELDS from sfaira.consts.ontologies import OntologyContainerSfaira +from sfaira.consts.utils import clean_cache OCS = OntologyContainerSfaira() diff --git a/sfaira/consts/adata_fields.py b/sfaira/consts/adata_fields.py index 9a99fc047..e38bcdb58 100644 --- a/sfaira/consts/adata_fields.py +++ b/sfaira/consts/adata_fields.py @@ -46,6 +46,7 @@ class AdataIds: obs_keys: List[str] var_keys: List[str] uns_keys: List[str] + batch_keys: List[str] classmap_source_key: str classmap_target_key: str @@ -127,6 +128,8 @@ def __init__(self): self.unknown_metadata_identifier = "unknown" self.unknown_metadata_ontology_id_identifier = "unknown" + self.batch_keys = [self.bio_sample, self.individual, self.tech_sample] + self.obs_keys = [ "assay_sc", "assay_differentiation", @@ -139,6 +142,7 @@ def __init__(self): "development_stage", "disease", "ethnicity", + "id", "individual", "organ", "organism", @@ -159,9 +163,6 @@ def __init__(self): "doi_preprint", "download_url_data", "download_url_meta", - "id", - "mapped_features", - "ncells", "normalization", "primary_data", "title", @@ -169,7 +170,7 @@ def __init__(self): "load_raw", "mapped_features", "remove_gene_version", - ] + ] + [x for x in self.obs_keys if x not in self.batch_keys] class AdataIdsCellxgene(AdataIds): @@ -181,6 +182,7 @@ class AdataIdsCellxgene(AdataIds): def __init__(self): self.assay_sc = "assay" + self.author = None self.cell_types_original = "free_annotation" # TODO "free_annotation" not always given # TODO: -> This will break streamlining though if self.cell_types_original is the same value as self.cellontology_class!! 
self.cellontology_class = "cell_type" @@ -189,8 +191,8 @@ def __init__(self): self.doi_journal = "publication_doi" self.doi_preprint = "preprint_doi" self.disease = "disease" - self.gene_id_symbols = "gene_symbol" - self.gene_id_ensembl = "ensembl" + self.gene_id_symbols = "index" + self.gene_id_ensembl = None # TODO not yet streamlined self.gene_id_index = self.gene_id_symbols self.id = "id" self.ncells = "ncells" @@ -215,10 +217,7 @@ def __init__(self): self.invalid_metadata_identifier = "na" self.unknown_metadata_ontology_id_identifier = "" - # accepted file names - self.accepted_file_names = [ - "krasnow_lab_human_lung_cell_atlas_smartseq2-2-remixed.h5ad", - ] + self.batch_keys = [] self.obs_keys = [ "assay_sc", diff --git a/sfaira/consts/directories.py b/sfaira/consts/directories.py new file mode 100644 index 000000000..1a863f9ca --- /dev/null +++ b/sfaira/consts/directories.py @@ -0,0 +1,14 @@ +""" +Paths to cache directories used throughout the code. +""" + +import os + +CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "cache") + +CACHE_DIR_DATABASES = os.path.join(CACHE_DIR, "dataset_meta") +CACHE_DIR_DATABASES_CELLXGENE = os.path.join(CACHE_DIR_DATABASES, "cellxgene") + +CACHE_DIR_GENOMES = os.path.join(CACHE_DIR, "genomes") + +CACHE_DIR_ONTOLOGIES = os.path.join(CACHE_DIR, "ontologies") diff --git a/sfaira/consts/ontologies.py b/sfaira/consts/ontologies.py index a83003939..0dc0da996 100644 --- a/sfaira/consts/ontologies.py +++ b/sfaira/consts/ontologies.py @@ -35,10 +35,12 @@ def __init__(self): self.default_embedding = None self._development_stage = None self._disease = None + self.doi = None + self.doi_main = None self.doi_journal = None self.doi_preprint = None self.ethnicity = { - "human": None, # TODO OntologyHancestro + "human": None, "mouse": None, } self.id = None @@ -54,6 +56,20 @@ def __init__(self): self.title = None self.year = OntologyList(terms=list(range(2000, 3000))) + def reload_ontology(self, attr): + kwargs = {"recache": True} + if attr == "assay_sc": + self._assay_sc = OntologySinglecellLibraryConstruction(**kwargs) + elif attr == "cell_line": + self._cell_line = OntologyCellosaurus(**kwargs) + elif attr == "cellontology_class": + self._cellontology_class = OntologyCl(branch=DEFAULT_CL, **kwargs) + elif attr == "disease": + self._disease = OntologyMondo(**kwargs) + elif attr == "organ": + self._organ = OntologyUberon(**kwargs) + return self._assay_sc + @property def assay_sc(self): if self._assay_sc is None: diff --git a/sfaira/consts/utils.py b/sfaira/consts/utils.py index 2826c8144..7a304a92a 100644 --- a/sfaira/consts/utils.py +++ b/sfaira/consts/utils.py @@ -1,4 +1,34 @@ import os +import shutil +from typing import Union + +from sfaira.consts.directories import CACHE_DIR, CACHE_DIR_DATABASES, CACHE_DIR_GENOMES, CACHE_DIR_ONTOLOGIES + + +def clean_cache(cache: Union[None, str] = None): + """ + Utility function to clean cached objects in paths of sfaira installation. + + This can be used to force re-caching or to reduce directory size. 
+ """ + if cache is not None: + cache_dir_dict = { + "all": CACHE_DIR, + "dataset_meta": CACHE_DIR_DATABASES, + "genomes": CACHE_DIR_GENOMES, + "ontologies": CACHE_DIR_ONTOLOGIES, + } + if cache not in cache_dir_dict.keys(): + raise ValueError(f"Did not find cache directory input {cache} in support list: " + f"{list(cache_dir_dict.keys())}") + else: + print(f"cleaning cache {cache} in directory {cache_dir_dict[cache]}") + # Assert that a path within sfaira is selected as a sanity check: + dir_to_delete = cache_dir_dict[cache] + dir_sfaira = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + assert str(dir_to_delete).startswith(dir_sfaira), \ + f"trying to delete outside of sfaira installation: {dir_to_delete}" + shutil.rmtree(dir_to_delete) def clean_doi(doi: str): diff --git a/sfaira/data/__init__.py b/sfaira/data/__init__.py index 5f692fb65..5878a81f9 100644 --- a/sfaira/data/__init__.py +++ b/sfaira/data/__init__.py @@ -1,6 +1,9 @@ -from sfaira.data.base import DatasetBase, DatasetGroup, DatasetGroupDirectoryOriented, \ - DatasetSuperGroup, load_store, DistributedStoreBase, DistributedStoreH5ad, DistributedStoreDao +from sfaira.data.dataloaders.base import DatasetBase, DatasetGroup, DatasetGroupDirectoryOriented, \ + DatasetSuperGroup +from sfaira.data.store import load_store, DistributedStoreSingleFeatureSpace, DistributedStoreMultipleFeatureSpaceBase, \ + DistributedStoresH5ad, DistributedStoresDao from . import dataloaders from .dataloaders import Universe from .interactive import DatasetInteractive +from . import store from . import utils diff --git a/sfaira/data/base/__init__.py b/sfaira/data/base/__init__.py deleted file mode 100644 index 0f9e9339a..000000000 --- a/sfaira/data/base/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from sfaira.data.base.dataset import DatasetBase -from sfaira.data.base.dataset_group import DatasetGroup, DatasetGroupDirectoryOriented, DatasetSuperGroup -from sfaira.data.base.distributed_store import load_store, DistributedStoreBase, DistributedStoreH5ad, \ - DistributedStoreDao diff --git a/sfaira/data/dataloaders/__init__.py b/sfaira/data/dataloaders/__init__.py index 3ecbc2fc5..c1dda2f56 100644 --- a/sfaira/data/dataloaders/__init__.py +++ b/sfaira/data/dataloaders/__init__.py @@ -1,3 +1,4 @@ +from . import base from . import databases from . 
import loaders from .super_group import Universe diff --git a/sfaira/data/dataloaders/base/__init__.py b/sfaira/data/dataloaders/base/__init__.py new file mode 100644 index 000000000..18b1d0ae0 --- /dev/null +++ b/sfaira/data/dataloaders/base/__init__.py @@ -0,0 +1,3 @@ +from sfaira.data.dataloaders.base.dataset import DatasetBase +from sfaira.data.dataloaders.base.dataset_group import DatasetGroup, DatasetGroupDirectoryOriented, DatasetSuperGroup +from sfaira.data.dataloaders.base.utils import clean_string diff --git a/sfaira/data/base/dataset.py b/sfaira/data/dataloaders/base/dataset.py similarity index 95% rename from sfaira/data/base/dataset.py rename to sfaira/data/dataloaders/base/dataset.py index 94a625af7..cd70e2c2f 100644 --- a/sfaira/data/base/dataset.py +++ b/sfaira/data/dataloaders/base/dataset.py @@ -21,12 +21,11 @@ from sfaira.versions.genomes import GenomeContainer from sfaira.versions.metadata import Ontology, OntologyHierarchical, CelltypeUniverse from sfaira.consts import AdataIds, AdataIdsCellxgene, AdataIdsSfaira, META_DATA_FIELDS, OCS -from sfaira.data.base.io_dao import write_dao +from sfaira.data.store.io_dao import write_dao +from sfaira.data.dataloaders.base.utils import is_child, clean_string, get_directory_formatted_doi from sfaira.data.utils import collapse_matrix, read_yaml from sfaira.consts.utils import clean_id_str -UNS_STRING_META_IN_OBS = "__obs__" - load_doc = \ """ @@ -37,39 +36,6 @@ """ -def is_child( - query, - ontology: Union[Ontology, bool, int, float, str, List[bool], List[int], List[float], List[str]], - ontology_parent=None, -) -> True: - """ - Check whether value is from set of allowed values using ontology. - - :param query: Value to attempt to set, only yield a single value and not a list. - :param ontology: Constraint for values. - Either ontology instance used to constrain entries, or list of allowed values. - :param ontology_parent: If ontology is a DAG, not only check if node is a DAG node but also whether it is a child - of this parent node. 
- :return: Whether attempted term is sub-term of allowed term in ontology - """ - if ontology_parent is None and ontology is None: - return True - else: - if isinstance(ontology, Ontology): - if ontology_parent is None: - return ontology.is_node(query) - else: - return ontology.is_a(query=query, reference=ontology_parent) - elif ontology is None: - return query == ontology_parent - else: - raise ValueError(f"did not recognize ontology type {type(ontology)}") - - -def get_directory_formatted_doi(x: str) -> str: - return "d" + "_".join("_".join("_".join(x.split("/")).split(".")).split("-")) - - class DatasetBase(abc.ABC): adata: Union[None, anndata.AnnData] class_maps: dict @@ -499,7 +465,7 @@ def _add_missing_featurenames( " dataloader") elif not self.gene_id_symbols_var_key and self.gene_id_ensembl_var_key: # Convert ensembl ids to gene symbols - id_dict = self.genome_container.id_to_names_dict + id_dict = self.genome_container.id_to_symbols_dict ensids = self.adata.var.index if self.gene_id_ensembl_var_key == "index" else self.adata.var[self.gene_id_ensembl_var_key] self.adata.var[gene_id_symbols] = [ id_dict[n.split(".")[0]] if n.split(".")[0] in id_dict.keys() else 'n/a' @@ -508,7 +474,7 @@ def _add_missing_featurenames( self.gene_id_symbols_var_key = gene_id_symbols elif self.gene_id_symbols_var_key and not self.gene_id_ensembl_var_key: # Convert gene symbols to ensembl ids - id_dict = self.genome_container.names_to_id_dict + id_dict = self.genome_container.symbol_to_id_dict id_strip_dict = self.genome_container.strippednames_to_id_dict # Matching gene names to ensembl ids in the following way: if the gene is present in the ensembl dictionary, # match it straight away, if it is not in there we try to match everything in front of the first period in @@ -679,7 +645,6 @@ def streamline_features( def streamline_metadata( self, schema: str = "sfaira", - uns_to_obs: bool = False, clean_obs: bool = True, clean_var: bool = True, clean_uns: bool = True, @@ -693,8 +658,6 @@ def streamline_metadata( :param schema: Export format. - "sfaira" - "cellxgene" - :param uns_to_obs: Whether to move metadata in .uns to .obs to make sure it's not lost when concatenating - multiple objects. Retains .id in .uns. :param clean_obs: Whether to delete non-streamlined fields in .obs, .obsm and .obsp. :param clean_var: Whether to delete non-streamlined fields in .var, .varm and .varp. :param clean_uns: Whether to delete non-streamlined fields in .uns. @@ -714,6 +677,7 @@ def streamline_metadata( if hasattr(adata_target_ids, "gene_id_ensembl") and not hasattr(self._adata_ids, "gene_id_ensembl"): raise ValueError(f"Cannot convert this object to schema {schema}, as the currently applied schema does not " f"have an ensembl gene ID annotation. 
Please run .streamline_features() first.") + experiment_batch_labels = [getattr(self._adata_ids, x) for x in self._adata_ids.batch_keys] # Creating new var annotation var_new = pd.DataFrame() @@ -751,64 +715,69 @@ def streamline_metadata( # Prepare new .uns dict: uns_new = {} for k in adata_target_ids.uns_keys: - val = getattr(self, k) - if val is None and hasattr(self, f"{k}_obs_key"): - val = np.sort(self.adata.obs[getattr(self, f"{k}_obs_key")].values.tolist()) + if hasattr(self, k) and getattr(self, k) is not None: + val = getattr(self, k) + elif hasattr(self, f"{k}_obs_key") and getattr(self, f"{k}_obs_key") is not None: + val = np.sort(np.unique(self.adata.obs[getattr(self, f"{k}_obs_key")].values)).tolist() + else: + val = None while hasattr(val, '__len__') and not isinstance(val, str) and len(val) == 1: # Unpack nested lists/tuples. val = val[0] uns_new[getattr(adata_target_ids, k)] = val # Prepare new .obs dataframe - experiment_batch_labels = ["bio_sample", "individual", "tech_sample"] per_cell_labels = ["cell_types_original", "cellontology_class", "cellontology_id"] obs_new = pd.DataFrame(index=self.adata.obs.index) # Handle non-cell type labels: for k in [x for x in adata_target_ids.obs_keys if x not in per_cell_labels]: - if hasattr(self, f"{k}_obs_key") and getattr(self, f"{k}_obs_key") is not None: - old_col = getattr(self, f"{k}_obs_key") - val = self.adata.obs[old_col].values.tolist() - else: - old_col = None - val = getattr(self, k) - if val is None: - val = self._adata_ids.unknown_metadata_identifier - # Unpack nested lists/tuples: - while hasattr(val, '__len__') and not isinstance(val, str) and len(val) == 1: - val = val[0] - val = [val] * self.adata.n_obs - new_col = getattr(adata_target_ids, k) # Handle batch-annotation columns which can be provided as a combination of columns separated by an asterisk - if old_col is not None and k in experiment_batch_labels and "*" in old_col: + if k in experiment_batch_labels and getattr(self, f"{k}_obs_key") is not None and \ + "*" in getattr(self, f"{k}_obs_key"): + old_cols = getattr(self, f"{k}_obs_key") batch_cols = [] - for batch_col in old_col.split("*"): + for batch_col in old_cols.split("*"): if batch_col in self.adata.obs_keys(): batch_cols.append(batch_col) else: # This should not occur in single data set loaders (see warning below) but can occur in # streamlined data loaders if not all instances of the streamlined data sets have all columns # in .obs set. - print(f"WARNING: attribute {new_col} of data set {self.id} was not found in column {batch_col}") + print(f"WARNING: attribute {batch_col} of data set {self.id} was not found in columns.") # Build a combination label out of all columns used to describe this group. 
val = [ "_".join([str(xxx) for xxx in xx]) for xx in zip(*[self.adata.obs[batch_col].values.tolist() for batch_col in batch_cols]) ] - # All other .obs fields are interpreted below as provided else: - # Check values for validity: - ontology = getattr(self.ontology_container_sfaira, k) \ - if hasattr(self.ontology_container_sfaira, k) else None - if k == "development_stage": - ontology = ontology[self.organism] - if k == "ethnicity": - ontology = ontology[self.organism] - self._value_protection(attr=new_col, allowed=ontology, attempted=[ - x for x in np.unique(val) - if x not in [ - self._adata_ids.unknown_metadata_identifier, - self._adata_ids.unknown_metadata_ontology_id_identifier, - ] - ]) + if hasattr(self, f"{k}_obs_key") and getattr(self, f"{k}_obs_key") is not None and \ + getattr(self, f"{k}_obs_key") in self.adata.obs.columns: + # Last and-clause to check if this column is included in data sets. This may be violated if data + # is obtained from a database which is not fully streamlined. + old_col = getattr(self, f"{k}_obs_key") + val = self.adata.obs[old_col].values.tolist() + else: + val = getattr(self, k) + if val is None: + val = self._adata_ids.unknown_metadata_identifier + # Unpack nested lists/tuples: + while hasattr(val, '__len__') and not isinstance(val, str) and len(val) == 1: + val = val[0] + val = [val] * self.adata.n_obs + new_col = getattr(adata_target_ids, k) + # Check values for validity: + ontology = getattr(self.ontology_container_sfaira, k) \ + if hasattr(self.ontology_container_sfaira, k) else None + if k == "development_stage": + ontology = ontology[self.organism] + if k == "ethnicity": + ontology = ontology[self.organism] + self._value_protection(attr=new_col, allowed=ontology, attempted=[ + x for x in np.unique(val) + if x not in [ + self._adata_ids.unknown_metadata_identifier, + self._adata_ids.unknown_metadata_ontology_id_identifier, + ] + ]) obs_new[new_col] = val setattr(self, f"{k}_obs_key", new_col) # Set cell types: @@ -878,20 +847,6 @@ def streamline_metadata( if self.adata.uns[k] is None or self.adata.uns[k] == unknown_old: self.adata.uns[k] = unknown_new - # Move all uns annotation to obs columns if requested - if uns_to_obs: - for k, v in self.adata.uns.items(): - if k not in self.adata.obs_keys(): - if v is None: - v = self._adata_ids.unknown_metadata_identifier - # Unpack nested lists/tuples: - while hasattr(v, '__len__') and not isinstance(v, str) and len(v) == 1: - v = v[0] - self.adata.obs[k] = [v for _ in range(self.adata.n_obs)] - # Retain only target uns keys in .uns. 
- self.adata.uns = dict([(k, v) for k, v in self.adata.uns.items() - if k in [getattr(adata_target_ids, kk) for kk in ["id"]]]) - # Add additional hard-coded description changes for cellxgene schema: if schema == "cellxgene": self.adata.uns["layer_descriptions"] = {"X": "raw"} @@ -2282,7 +2237,7 @@ def _value_protection( else: raise ValueError(f"'{x}' is not a valid entry for {attr}.") else: - raise ValueError(f"allowed of type {type(allowed)} is not a valid entry for {attr}.") + raise ValueError(f"argument allowed of type {type(allowed)} is not a valid entry for {attr}.") # Flatten attempts if only one was made: if len(attempted_clean) == 1: attempted_clean = attempted_clean[0] diff --git a/sfaira/data/base/dataset_group.py b/sfaira/data/dataloaders/base/dataset_group.py similarity index 98% rename from sfaira/data/base/dataset_group.py rename to sfaira/data/dataloaders/base/dataset_group.py index 33b37f85b..7ff2739fe 100644 --- a/sfaira/data/base/dataset_group.py +++ b/sfaira/data/dataloaders/base/dataset_group.py @@ -12,8 +12,9 @@ from typing import Dict, List, Union import warnings -from sfaira.data.base.dataset import is_child, DatasetBase -from sfaira.versions.genomes import GenomeContainer +from sfaira.data.dataloaders.base.dataset import DatasetBase +from sfaira.data.dataloaders.base.utils import is_child +from sfaira.versions.genomes.genomes import GenomeContainer from sfaira.consts import AdataIds, AdataIdsSfaira from sfaira.data.utils import read_yaml @@ -65,10 +66,12 @@ class DatasetGroup: #dsg_humanlung.adata """ datasets: Dict[str, DatasetBase] + _collection_id: str - def __init__(self, datasets: dict): + def __init__(self, datasets: dict, collection_id: str = "default"): self._adata_ids = AdataIdsSfaira() self.datasets = datasets + self._collection_id = collection_id @property def _unknown_celltype_identifiers(self): @@ -154,7 +157,6 @@ def func(dataset, **kwargs_func): def streamline_metadata( self, schema: str = "sfaira", - uns_to_obs: bool = False, clean_obs: bool = True, clean_var: bool = True, clean_uns: bool = True, @@ -167,7 +169,6 @@ def streamline_metadata( :param schema: Export format. - "sfaira" - "cellxgene" - :param uns_to_obs: Whether to move metadata in .uns to .obs to make sure it's not lost when concatenating multiple objects. :param clean_obs: Whether to delete non-streamlined fields in .obs, .obsm and .obsp. :param clean_var: Whether to delete non-streamlined fields in .var, .varm and .varp. :param clean_uns: Whether to delete non-streamlined fields in .uns. @@ -177,7 +178,6 @@ def streamline_metadata( for x in self.ids: self.datasets[x].streamline_metadata( schema=schema, - uns_to_obs=uns_to_obs, clean_obs=clean_obs, clean_var=clean_var, clean_uns=clean_uns, @@ -340,6 +340,10 @@ def download(self, **kwargs): def ids(self): return list(self.datasets.keys()) + @property + def collection_id(self): + return self._collection_id + @property def adata_ls(self): adata_ls = [] @@ -618,13 +622,13 @@ def __init__( # Collect all data loaders from files in directory: datasets = [] self._cwd = os.path.dirname(file_base) - dataset_module = str(self._cwd.split("/")[-1]) + collection_id = str(self._cwd.split("/")[-1]) package_source = "sfaira" if str(self._cwd.split("/")[-5]) == "sfaira" else "sfairae" loader_pydoc_path_sfaira = "sfaira.data.dataloaders.loaders." loader_pydoc_path_sfairae = "sfaira_extension.data.dataloaders.loaders." 
loader_pydoc_path = loader_pydoc_path_sfaira if package_source == "sfaira" else loader_pydoc_path_sfairae if "group.py" in os.listdir(self._cwd): - DatasetGroupFound = pydoc.locate(loader_pydoc_path + dataset_module + ".group.DatasetGroup") + DatasetGroupFound = pydoc.locate(loader_pydoc_path + collection_id + ".group.DatasetGroup") dsg = DatasetGroupFound(data_path=data_path, meta_path=meta_path, cache_path=cache_path) datasets.extend(list(dsg.datasets.values)) else: @@ -634,23 +638,23 @@ def __init__( if f.split(".")[-1] == "py" and f.split(".")[0] not in ["__init__", "base", "group"]: datasets_f = [] file_module = ".".join(f.split(".")[:-1]) - DatasetFound = pydoc.locate(loader_pydoc_path + dataset_module + "." + file_module + ".Dataset") + DatasetFound = pydoc.locate(loader_pydoc_path + collection_id + "." + file_module + ".Dataset") # Load objects from name space: # - load(): Loading function that return anndata instance. # - SAMPLE_FNS: File name list for DatasetBaseGroupLoadingManyFiles - load_func = pydoc.locate(loader_pydoc_path + dataset_module + "." + file_module + ".load") + load_func = pydoc.locate(loader_pydoc_path + collection_id + "." + file_module + ".load") load_func_annotation = \ - pydoc.locate(loader_pydoc_path + dataset_module + "." + file_module + ".LOAD_ANNOTATION") + pydoc.locate(loader_pydoc_path + collection_id + "." + file_module + ".LOAD_ANNOTATION") # Also check sfaira_extension for additional load_func_annotation: if package_source != "sfairae": - load_func_annotation_sfairae = pydoc.locate(loader_pydoc_path_sfairae + dataset_module + + load_func_annotation_sfairae = pydoc.locate(loader_pydoc_path_sfairae + collection_id + "." + file_module + ".LOAD_ANNOTATION") # LOAD_ANNOTATION is a dictionary so we can use update to extend it. if load_func_annotation_sfairae is not None and load_func_annotation is not None: load_func_annotation.update(load_func_annotation_sfairae) elif load_func_annotation_sfairae is not None and load_func_annotation is None: load_func_annotation = load_func_annotation_sfairae - sample_fns = pydoc.locate(loader_pydoc_path + dataset_module + "." + file_module + + sample_fns = pydoc.locate(loader_pydoc_path + collection_id + "." + file_module + ".SAMPLE_FNS") fn_yaml = os.path.join(self._cwd, file_module + ".yaml") fn_yaml = fn_yaml if os.path.exists(fn_yaml) else None @@ -698,7 +702,7 @@ def __init__( datasets.extend(datasets_f) keys = [x.id for x in datasets] - super().__init__(datasets=dict(zip(keys, datasets))) + super().__init__(datasets=dict(zip(keys, datasets)), collection_id=collection_id) def clean_ontology_class_map(self): """ @@ -1158,7 +1162,6 @@ def load_cached_backed(self, fn: PathLike): def streamline_metadata( self, schema: str = "sfaira", - uns_to_obs: bool = False, clean_obs: bool = True, clean_var: bool = True, clean_uns: bool = True, @@ -1171,7 +1174,6 @@ def streamline_metadata( :param schema: Export format. - "sfaira" - "cellxgene" - :param uns_to_obs: Whether to move metadata in .uns to .obs to make sure it's not lost when concatenating multiple objects. :param clean_obs: Whether to delete non-streamlined fields in .obs, .obsm and .obsp. :param clean_var: Whether to delete non-streamlined fields in .var, .varm and .varp. :param clean_uns: Whether to delete non-streamlined fields in .uns. 
@@ -1182,7 +1184,6 @@ def streamline_metadata( for xx in x.ids: x.datasets[xx].streamline_metadata( schema=schema, - uns_to_obs=uns_to_obs, clean_obs=clean_obs, clean_var=clean_var, clean_uns=clean_uns, diff --git a/sfaira/data/dataloaders/base/utils.py b/sfaira/data/dataloaders/base/utils.py new file mode 100644 index 000000000..b1a36975e --- /dev/null +++ b/sfaira/data/dataloaders/base/utils.py @@ -0,0 +1,43 @@ +from typing import List, Union +from sfaira.versions.metadata import Ontology + +UNS_STRING_META_IN_OBS = "__obs__" + + +def is_child( + query, + ontology: Union[Ontology, bool, int, float, str, List[bool], List[int], List[float], List[str]], + ontology_parent=None, +) -> True: + """ + Check whether value is from set of allowed values using ontology. + + :param query: Value to attempt to set, only yield a single value and not a list. + :param ontology: Constraint for values. + Either ontology instance used to constrain entries, or list of allowed values. + :param ontology_parent: If ontology is a DAG, not only check if node is a DAG node but also whether it is a child + of this parent node. + :return: Whether attempted term is sub-term of allowed term in ontology + """ + if ontology_parent is None and ontology is None: + return True + else: + if isinstance(ontology, Ontology): + if ontology_parent is None: + return ontology.is_node(query) + else: + return ontology.is_a(query=query, reference=ontology_parent) + elif ontology is None: + return query == ontology_parent + else: + raise ValueError(f"did not recognize ontology type {type(ontology)}") + + +def clean_string(s): + if s is not None: + s = s.replace(',', '').replace(' ', '').replace('-', '').replace('_', '').replace("'", '').lower() + return s + + +def get_directory_formatted_doi(x: str) -> str: + return "d" + "_".join("_".join("_".join(x.split("/")).split(".")).split("-")) diff --git a/sfaira/data/dataloaders/databases/cellxgene/__init__.py b/sfaira/data/dataloaders/databases/cellxgene/__init__.py index 472d880aa..00c4b307b 100644 --- a/sfaira/data/dataloaders/databases/cellxgene/__init__.py +++ b/sfaira/data/dataloaders/databases/cellxgene/__init__.py @@ -1,2 +1,2 @@ -from sfaira.data.dataloaders.databases.cellxgene.cellxgene_group import DatasetGroupCellxgene +from sfaira.data.dataloaders.databases.cellxgene.cellxgene_group import DatasetSuperGroupCellxgene, DatasetGroupCellxgene from sfaira.data.dataloaders.databases.cellxgene.cellxgene_loader import Dataset diff --git a/sfaira/data/dataloaders/databases/cellxgene/cellxgene_group.py b/sfaira/data/dataloaders/databases/cellxgene/cellxgene_group.py index 4091e5765..a19f6412d 100644 --- a/sfaira/data/dataloaders/databases/cellxgene/cellxgene_group.py +++ b/sfaira/data/dataloaders/databases/cellxgene/cellxgene_group.py @@ -1,26 +1,102 @@ -import os +import datetime +from IPython.display import display, display_javascript, display_html +import json +import pydoc from typing import Union +import uuid -from sfaira.data import DatasetGroup -from sfaira.consts import AdataIdsCellxgene +from sfaira.data.dataloaders.base import DatasetGroup, DatasetSuperGroup -from .cellxgene_loader import Dataset +from sfaira.data.dataloaders.databases.cellxgene.cellxgene_loader import Dataset +from sfaira.data.dataloaders.databases.cellxgene.rest_helpers import get_collection, get_collections class DatasetGroupCellxgene(DatasetGroup): + collection_id: str + def __init__( - self, - data_path: Union[str, None] = None, - meta_path: Union[str, None] = None, - cache_path: Union[str, None] = None + 
self, + collection_id: str = "default", + data_path: Union[str, None] = None, + meta_path: Union[str, None] = None, + cache_path: Union[str, None] = None, + verbose: int = 0, ): - self._adata_ids_cellxgene = AdataIdsCellxgene() - fn_ls = os.listdir(data_path) - fn_ls = [x for x in fn_ls if x in self._adata_ids_cellxgene.accepted_file_names] + self._collection = None + dataset_ids = [x["id"] for x in get_collection(collection_id=collection_id)['datasets']] + loader_pydoc_path_sfaira = "sfaira.data.dataloaders.databases.cellxgene.cellxgene_loader" + load_func = pydoc.locate(loader_pydoc_path_sfaira + ".load") datasets = [ - Dataset(data_path=data_path, fn=x, meta_path=meta_path, cache_path=cache_path) - for x in fn_ls + Dataset( + collection_id=collection_id, + data_path=data_path, + meta_path=meta_path, + cache_path=cache_path, + load_func=load_func, + sample_fn=x, + sample_fns=dataset_ids, + verbose=verbose, + ) + for x in dataset_ids ] keys = [x.id for x in datasets] - super().__init__(dict(zip(keys, datasets))) + super().__init__(datasets=dict(zip(keys, datasets)), collection_id=collection_id) + + @property + def collection(self): + if self._collection is None: + self._collection = get_collection(collection_id=self.collection_id) + return self._collection + + def show_summary(self): + uuid_session = str(uuid.uuid4()) + display_html('
'.format(uuid_session), raw=True) + display_javascript(""" + require(["https://rawgit.com/caldwell/renderjson/master/renderjson.js"], function() { + document.getElementById('%s').appendChild(renderjson(%s)) + }); + """ % (uuid_session, json.dumps(self.collection)), raw=True) + + +class DatasetSuperGroupCellxgene(DatasetSuperGroup): + + def __init__( + self, + data_path: Union[str, None] = None, + meta_path: Union[str, None] = None, + cache_path: Union[str, None] = None, + verbose: int = 0, + ): + self._collections = None + # Get all collection IDs and instantiate one data set group per collection. + # Note that the collection itself is not passed to DatasetGroupCellxgene but only the ID string. + dataset_groups = [ + DatasetGroupCellxgene( + collection_id=x["id"], + data_path=data_path, + meta_path=meta_path, + cache_path=cache_path, + verbose=verbose, + ) + for x in self.collections + ] + super().__init__(dataset_groups=dataset_groups) + + @property + def collections(self): + """ + Yield all collections available from REST API. + """ + if self._collections is None: + self._collections = get_collections() + return self._collections + + def show_summary(self): + """ + Prints overview of all collections available. + """ + display("There are " + str(len(self.collections)) + " public collections sorting by newest first:") + for collection in sorted(self.collections, key=lambda key: key['created_at'], reverse=True): + display("id: " + collection['id'] + ' created on: ' + + datetime.date.fromtimestamp(collection['created_at']).strftime("%m/%d/%y")) diff --git a/sfaira/data/dataloaders/databases/cellxgene/cellxgene_loader.py b/sfaira/data/dataloaders/databases/cellxgene/cellxgene_loader.py index b4cd6daff..3855f4c56 100644 --- a/sfaira/data/dataloaders/databases/cellxgene/cellxgene_loader.py +++ b/sfaira/data/dataloaders/databases/cellxgene/cellxgene_loader.py @@ -1,34 +1,66 @@ import anndata +from IPython.display import display_javascript, display_html +import json import os -from typing import Union +import pathlib +import pickle +import requests +from typing import List, Union +import uuid -from sfaira.data import DatasetBase +from sfaira.data.dataloaders.base import DatasetBase from sfaira.consts import AdataIdsCellxgene +from sfaira.consts.directories import CACHE_DIR_DATABASES_CELLXGENE +from sfaira.data.dataloaders.databases.cellxgene.rest_helpers import get_collection, get_data +from sfaira.data.dataloaders.databases.cellxgene.rest_helpers import CELLXGENE_PRODUCTION_ENDPOINT, DOWNLOAD_DATASET + + +def cellxgene_fn(dir, dataset_id): + return os.path.join(dir, dataset_id + ".h5ad") class Dataset(DatasetBase): """ This is a dataloader for downloaded h5ad from cellxgene. - :param data_path: - :param meta_path: - :param kwargs: + In contrast to the base class, each instance is coupled to a particular collection_id to allow query. + In the base classes, collection_id are only defined on the group level. 
""" + collection_id: str + def __init__( self, - data_path: Union[str, None], - fn: str, + collection_id: str = "default", + data_path: Union[str, None] = None, meta_path: Union[str, None] = None, + cache_path: Union[str, None] = None, + load_func=None, + dict_load_func_annotation=None, + yaml_path: Union[str, None] = None, + sample_fn: Union[str, None] = None, + sample_fns: Union[List[str], None] = None, + additional_annotation_key: Union[str, None] = None, + verbose: int = 0, **kwargs ): - super().__init__(data_path=data_path, meta_path=meta_path, **kwargs) + super().__init__( + data_path=data_path, + meta_path=meta_path, + cache_path=cache_path, + load_func=load_func, + sample_fn=sample_fn, + sample_fns=sample_fns, + ) self._adata_ids_cellxgene = AdataIdsCellxgene() - self.fn = fn + self._collection = None + # The h5ad objects from cellxgene follow a particular structure and the following attributes are guaranteed to + # be in place. Note that these point at the anndata instance and will only be available for evaluation after + # download. See below for attributes that are lazily available self.cellontology_class_obs_key = self._adata_ids_cellxgene.cellontology_class self.cellontology_id_obs_key = self._adata_ids_cellxgene.cellontology_id - self.cell_types_original_obs_key = self._adata_ids_cellxgene.cell_types_original + self.cellontology_original_obs_key = self._adata_ids_cellxgene.cell_types_original self.development_stage_obs_key = self._adata_ids_cellxgene.development_stage self.disease_obs_key = self._adata_ids_cellxgene.disease self.ethnicity_obs_key = self._adata_ids_cellxgene.ethnicity @@ -36,29 +68,181 @@ def __init__( self.organ_obs_key = self._adata_ids_cellxgene.organism self.state_exact_obs_key = self._adata_ids_cellxgene.state_exact - self.gene_id_ensembl_var_key = self._adata_ids_cellxgene.gene_id_ensembl self.gene_id_symbols_var_key = self._adata_ids_cellxgene.gene_id_symbols - def _load(self): + self._unknown_celltype_identifiers = self._adata_ids_cellxgene.unknown_celltype_identifier + + self.collection_id = collection_id + self.supplier = "cellxgene" + doi = [x['link_url'] for x in self.collection["links"] if x['link_type'] == 'DOI'] + self.doi_journal = collection_id if len(doi) == 0 else doi[0] # TODO access journal DOI explicitly. + self.id = sample_fn + # Set h5ad download URLs: + download_url_data = [] + for asset in self._collection_dataset['dataset_assets']: + if asset['filetype'].lower() == "h5ad": + download_url_data.append(CELLXGENE_PRODUCTION_ENDPOINT + DOWNLOAD_DATASET + asset['dataset_id'] + + "/asset/" + asset['id']) + self.download_url_data = download_url_data + + # Set dataset-wise attributes based on object preview from REST API (without h5ad download): + # Set organism first so that other terms can access this attribute (e.g. developmental_stage ontology). + reordered_keys = ["organism"] + [x for x in self._adata_ids_cellxgene.dataset_keys if x != "organism"] + for k in reordered_keys: + val = self._collection_dataset[getattr(self._adata_ids_cellxgene, k)] + # Unique label if list is length 1: + # Otherwise do not set property and resort to cell-wise labels. + if isinstance(val, dict) or k == "sex": + val = [val] + v_clean = [] + for v in val: + if k == "sex": + v = v[0] + else: + # Decide if labels are read from name or ontology ID: + if k == "disease" and (v["label"].lower() == "normal" or v["label"].lower() == "healthy"): + # TODO normal state label varies in disease annotation. This can be removed once streamlined. 
+ v = "healthy" + elif k in ["assay_sc", "disease", "organ"] and \ + v["ontology_term_id"] != self._adata_ids_cellxgene.unknown_metadata_ontology_id_identifier: + v = v["ontology_term_id"] + else: + v = v["label"] + # Organ labels contain labels on tissue type also, such as 'UBERON:0001911 (cell culture)'. + if k == "organ": + v = v.split(" ")[0] + if k == "organism": + organism_map = { + "Homo sapiens": "human", + "Mus musculus": "mouse", + } + if v not in organism_map: + raise ValueError(f"value {v} not recognized") + v = organism_map[v] + if v != self._adata_ids_cellxgene.unknown_metadata_ontology_id_identifier and \ + v != self._adata_ids_cellxgene.invalid_metadata_identifier: + v_clean.append(v) + try: + # Set as single element or list if multiple entries are given. + if len(v_clean) == 1: + setattr(self, k, v_clean[0]) + else: + setattr(self, k, v_clean) + except ValueError as e: + if verbose > 0: + print(f"WARNING: {e} in {self.collection_id} and data set {self.id}") + # Add author information. # TODO need to change this to contributor? + setattr(self, "author", "cellxgene") + + @property + def _collection_cache_dir(self): + """ + The cache dir is in a cache directory in the sfaira installation that is excempt from git versioning. + """ + cache_dir_path = pathlib.Path(CACHE_DIR_DATABASES_CELLXGENE) + cache_dir_path.mkdir(parents=True, exist_ok=True) + return CACHE_DIR_DATABASES_CELLXGENE + + @property + def _collection_cache_fn(self): + return os.path.join(self._collection_cache_dir, self.collection_id + ".pickle") + + @property + def collection(self): + if self._collection is None: + # Check if cached: + if os.path.exists(self._collection_cache_fn): + with open(self._collection_cache_fn, "rb") as f: + self._collection = pickle.load(f) + else: + # Download and cache: + self._collection = get_collection(collection_id=self.collection_id) + with open(self._collection_cache_fn, "wb") as f: + pickle.dump(obj=self._collection, file=f) + return self._collection + + @property + def _collection_dataset(self): + return self.collection['datasets'][self._sample_fns.index(self.sample_fn)] + + @property + def directory_formatted_doi(self) -> str: + return self.collection_id + + def load( + self, + remove_gene_version: bool = True, + match_to_reference: Union[str, bool, None] = None, + load_raw: bool = False, + allow_caching: bool = True, + set_metadata: bool = True, + **kwargs + ): + # Invoke load with cellxgene adapted parameters: + # - Never cache as the cellxgene objects already fast to read. + super().load( + remove_gene_version=False, + match_to_reference=match_to_reference, + load_raw=True, + allow_caching=False, + set_metadata=set_metadata, + **kwargs + ) + + def download(self, filetype: str = "h5ad", verbose: int = 0): """ - Note that in contrast to data set specific data loaders, here, the core attributes are only identified from - the data in this function and are not already set in the constructor. These attributes can still be - used through meta data containers after the data was loaded once. - :return: + Only download if file does not already exist. + + :param filetype: File type to download. 
+ + - "h5ad" + - "rds" + - "loom" """ - fn = os.path.join(self.data_dir_base, self.fn) - adata = anndata.read(fn) + counter = 0 + if not os.path.exists(os.path.join(self.data_dir_base, self.directory_formatted_doi)): + os.makedirs(os.path.join(self.data_dir_base, self.directory_formatted_doi)) + for asset in self._collection_dataset['dataset_assets']: + if asset['filetype'].lower() == filetype: + # Only download if file does not already exist: + fn = cellxgene_fn(dir=self.data_dir, dataset_id=self.sample_fn) + if not os.path.isfile(fn): + counter += 1 + assert counter < 2, f"found more than one {filetype} for data set {self.sample_fn}" + url = CELLXGENE_PRODUCTION_ENDPOINT + DOWNLOAD_DATASET + asset['dataset_id'] + "/asset/" + \ + asset['id'] + r = requests.post(url) + r.raise_for_status() + presigned_url = r.json()['presigned_url'] + # Report: + headers = {'range': 'bytes=0-0'} + r1 = requests.get(presigned_url, headers=headers) + if r1.status_code == requests.codes.partial: + if verbose > 0: + print(f"Downloading {r1.headers['Content-Range']} from {r1.headers['Server']}") + # Download: + open(fn, 'wb').write(get_data(presigned_url=presigned_url)) + + def show_summary(self): + uuid_session = str(uuid.uuid4()) + display_html('
'.format(uuid_session), raw=True) + display_javascript(""" + require(["https://rawgit.com/caldwell/renderjson/master/renderjson.js"], function() { + document.getElementById('%s').appendChild(renderjson(%s)) + }); + """ % (uuid_session, json.dumps(self._collection_dataset)), raw=True) + + +def load(data_dir, sample_fn, **kwargs): + """ + Generalised load function for cellxgene-provided data sets. + + This function corresponds to the dataset-wise load() functions defined in standard sfaira data loaders. + """ + fn = cellxgene_fn(dir=data_dir, dataset_id=sample_fn) + adata = anndata.read_h5ad(fn) + if adata.raw is not None: # TODO still need this? adata.X = adata.raw.X - # TODO delete raw? - - self.author = adata.uns[self._adata_ids_cellxgene.author][self._adata_ids_cellxgene.author_names] - self.doi = adata.uns[self._adata_ids_cellxgene.doi] - self.download_url_data = self.download_url_data - self.id = self.id - self.normalization = 'raw' - self.organ = str(self.fn).split("_")[3] # TODO interface this properly - # self.organ = adata.obs["tissue"].values[0] - self.organism = adata.obs[self._adata_ids_cellxgene.organism].values[0] - self.assay_sc = adata.obs[self._adata_ids_cellxgene.assay_sc].values[0] - self.year = adata.uns[self._adata_ids_cellxgene.year] + del adata.raw + return adata diff --git a/sfaira/data/dataloaders/databases/cellxgene/rest_helpers.py b/sfaira/data/dataloaders/databases/cellxgene/rest_helpers.py new file mode 100644 index 000000000..8fd94c13c --- /dev/null +++ b/sfaira/data/dataloaders/databases/cellxgene/rest_helpers.py @@ -0,0 +1,66 @@ +""" +Helper functionalities to interact with cellxgene REST API. +""" + +import requests +from requests.adapters import HTTPAdapter +from requests.packages.urllib3.util.retry import Retry + +CELLXGENE_PRODUCTION_ENDPOINT = 'https://api.cellxgene.cziscience.com' +COLLECTIONS = "/dp/v1/collections/" +DOWNLOAD_DATASET = "/dp/v1/datasets/" +MAX_RETRIES = 5 +TIMEOUT_COLLECTION = 5 +TIMEOUT_DATA = 10 +HTTP_ERROR_LIST = [429, 502, 504] + + +class CustomHTTPAdapter(HTTPAdapter): + def __init__(self, timeout, **kwargs): + self.timeout = timeout + super().__init__(**kwargs) + + def send(self, request, **kwargs): + kwargs["timeout"] = self.timeout + return super().send(request, **kwargs) + + +def rest_api_collection_request(url): + retry_strategy = Retry( + backoff_factor=0, + method_whitelist=["GET"], + status_forcelist=HTTP_ERROR_LIST, + total=MAX_RETRIES, + ) + adapter = CustomHTTPAdapter(timeout=TIMEOUT_COLLECTION, max_retries=retry_strategy) + https = requests.Session() + https.mount("https://", adapter) + r = https.get(url) + r.raise_for_status() + return r.json() + + +def rest_api_data_request(presigned_url): + retry_strategy = Retry( + backoff_factor=0, + method_whitelist=["GET"], + status_forcelist=HTTP_ERROR_LIST, + total=MAX_RETRIES, + ) + adapter = CustomHTTPAdapter(timeout=TIMEOUT_DATA, max_retries=retry_strategy) + https = requests.Session() + https.mount("https://", adapter) + r = https.get(presigned_url) + return r.content + + +def get_collection(collection_id): + return rest_api_collection_request(url=CELLXGENE_PRODUCTION_ENDPOINT + COLLECTIONS + collection_id) + + +def get_collections(): + return rest_api_collection_request(url=CELLXGENE_PRODUCTION_ENDPOINT + COLLECTIONS)['collections'] + + +def get_data(presigned_url): + return rest_api_data_request(presigned_url=presigned_url) diff --git a/sfaira/data/dataloaders/databases/super_group.py b/sfaira/data/dataloaders/databases/super_group.py index 96405ed62..2f3dada96 100644 --- 
a/sfaira/data/dataloaders/databases/super_group.py +++ b/sfaira/data/dataloaders/databases/super_group.py @@ -1,7 +1,7 @@ from typing import Union -from sfaira.data import DatasetSuperGroup -from sfaira.data.dataloaders.databases.cellxgene import DatasetGroupCellxgene +from sfaira.data.dataloaders.base.dataset_group import DatasetGroup, DatasetSuperGroup +from sfaira.data.dataloaders.databases.cellxgene import DatasetSuperGroupCellxgene class DatasetSuperGroupDatabases(DatasetSuperGroup): @@ -12,11 +12,11 @@ def __init__( meta_path: Union[str, None] = None, cache_path: Union[str, None] = None, ): - dataset_groups = [] - # List all data bases here: - dataset_groups.append(DatasetGroupCellxgene( - data_path=data_path, - meta_path=meta_path, - cache_path=cache_path - )) - super().__init__(dataset_groups=dataset_groups) + dataset_super_groups = [ + DatasetSuperGroupCellxgene( + data_path=data_path, + meta_path=meta_path, + cache_path=cache_path + ), + ] + super().__init__(dataset_groups=dataset_super_groups) diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml index dda8e81a9..3a1b866d4 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml @@ -19,7 +19,7 @@ dataset_wise: primary_data: year: 2019 dataset_or_observation_wise: - assay_sc: "10x technology" + assay_sc: "10x 3' v2" assay_sc_obs_key: assay_differentiation: assay_differentiation_obs_key: diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py index 94eb3decb..802854189 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py @@ -13,7 +13,8 @@ def __init__(self, **kwargs): self.download_url_data = "https://covid19.cog.sanger.ac.uk/smillie19_epi.processed.h5ad" self.download_url_meta = None - self.assay_sc = "10x technology" + # Note: They used used both 10x 3' v2 and 10x 3' v3. 
+ self.assay_sc = "10x 3' transcription profiling" self.author = "Smilie" self.disease = "healthy" self.doi_journal = "10.1016/j.cell.2019.06.029" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py index a0b2f0bec..c3aff84f9 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py @@ -13,7 +13,7 @@ def __init__(self, **kwargs): self.download_url_data = "https://covid19.cog.sanger.ac.uk/martin19.processed.h5ad" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Martin" self.disease = "healthy" self.doi_journal = "10.1016/j.cell.2019.08.008" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py index dce8f3174..7a46ff91e 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py @@ -18,7 +18,7 @@ def __init__(self, **kwargs): self.download_url_data = "https://covid19.cog.sanger.ac.uk/henry18_0.processed.h5ad" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Henry" self.disease = "healthy" self.doi_journal = "10.1016/j.celrep.2018.11.086" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py index f8634dfbb..7aff4b8f9 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py @@ -13,7 +13,7 @@ def __init__(self, **kwargs): self.download_url_data = "https://covid19.cog.sanger.ac.uk/miller20.processed.h5ad" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Miller" self.disease = "healthy" self.doi_journal = "10.1016/j.devcel.2020.01.033" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py index 737dd21c8..fda7a7ea1 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py @@ -13,7 +13,7 @@ def __init__(self, **kwargs): self.download_url_data = "https://covid19.cog.sanger.ac.uk/guo18_donor.processed.h5ad" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Guo" self.disease = "healthy" self.doi_journal = "10.1038/s41422-018-0099-2" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py 
b/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py index 057da63fd..9039758bd 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py @@ -12,7 +12,7 @@ def __init__(self, **kwargs): self.download_url_data = "private,GSE115469.csv.gz" self.download_url_meta = "private,GSE115469_labels.txt" - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "MacParland" self.disease = "healthy" self.doi_journal = "10.1038/s41467-018-06318-7" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py index a05993e6d..87a40002c 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py @@ -58,7 +58,7 @@ def __init__(self, **kwargs): "private,donor2.annotation.txt" ] - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' transcription profiling" self.author = "Szabo" self.doi_journal = "10.1038/s41467-019-12464-3" self.individual = SAMPLE_DICT[self.sample_fn][1] diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py index 86e8bdcf8..f09696431 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py @@ -11,7 +11,7 @@ def __init__(self, **kwargs): self.download_url_data = "https://covid19.cog.sanger.ac.uk/menon19.processed.h5ad" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v3" self.author = "Menon" self.disease = "healthy" self.doi_journal = "10.1038/s41467-019-12780-8" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py index 49a3774f2..9af36b696 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py @@ -19,7 +19,7 @@ def __init__(self, **kwargs): self.download_url_meta = f"https://www.ebi.ac.uk/arrayexpress/files/{self.sample_fn.split('.')[0]}/" \ f"{self.sample_fn}.2.zip" - self.assay_sc = "10x technology" if self.sample_fn == "E-MTAB-6678.processed" else "Smart-seq2" + self.assay_sc = "10x 3' v2" if self.sample_fn == "E-MTAB-6678.processed" else "Smart-seq2" self.author = "Ventotormo" self.disease = "healthy" self.doi_journal = "10.1038/s41586-018-0698-6" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py index de71d5beb..eac9db452 100644 --- 
a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py @@ -14,7 +14,7 @@ def __init__(self, **kwargs): self.download_url_data = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3433/tissue.rdata" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Ramachandran" self.doi_journal = "10.1038/s41586-019-1631-3" self.normalization = "raw" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py index b5c0f4d85..65f6d5ac7 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py @@ -11,7 +11,7 @@ def __init__(self, **kwargs): self.download_url_data = "private,fetal_liver_alladata_.h5ad" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Popescu" self.disease = "healthy" self.doi_journal = "10.1038/s41586-019-1652-y" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml b/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml index 0850a276a..8fc3e343b 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml @@ -18,7 +18,7 @@ dataset_wise: year: 2020 dataset_or_observation_wise: assay_sc: - droplet_normal_lung_blood_scanpy.20200205.RC4.h5ad: "10x technology" + droplet_normal_lung_blood_scanpy.20200205.RC4.h5ad: "10x 3' v2" facs_normal_lung_blood_scanpy.20200205.RC4.h5ad: "Smart-seq2" assay_sc_obs_key: assay_differentiation: diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py index 34ba3268d..31d204574 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py @@ -13,10 +13,11 @@ def __init__(self, **kwargs): self.download_url_data = "https://covid19.cog.sanger.ac.uk/james20.processed.h5ad" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc_obs_key = "assay" self.author = "James" self.disease = "healthy" self.doi_journal = "10.1038/s41590-020-0602-z" + self.individual_obs_key = "donor" self.normalization = "raw" self.organ = "colon" self.organism = "human" @@ -36,5 +37,12 @@ def load(data_dir, **kwargs): adata = anndata.read(fn) adata.X = np.expm1(adata.X) adata.X = adata.X.multiply(scipy.sparse.csc_matrix(adata.obs["n_counts"].values[:, None])).multiply(1 / 10000) - + # Assay maps are described here: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7212050/ + adata.obs["assay"] = [{ + '290b': "10x 3' transcription profiling", + '298c': "10x 3' transcription profiling", + '302c': "10x 3' transcription profiling", + '390c': "10x 5' transcription profiling", + '417c': "10x 
5' transcription profiling", + }[x] for x in adata.obs["donor"].values] return adata diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py b/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py index c2c655fcd..c6e0b9f3e 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py @@ -17,7 +17,7 @@ def __init__(self, **kwargs): self.download_url_data = f"https://covid19.cog.sanger.ac.uk/{self.sample_fn}" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' transcription profiling" self.author = "Braga" self.disease = "healthy" self.doi_journal = "10.1038/s41591-019-0468-5" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py index 7d7465ef8..07ccb350d 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py @@ -16,7 +16,7 @@ def __init__(self, **kwargs): self.download_url_meta = \ "https://www.brainimmuneatlas.org/data_files/toDownload/annot_fullAggr.csv" - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Hove" self.disease = "healthy" self.doi_journal = "10.1038/s41593-019-0393-4" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41597_019_0351_8/human_kidney_2020_10xsequencing_liao_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41597_019_0351_8/human_kidney_2020_10xsequencing_liao_001.py index 8b0ea9f8f..e5116a21a 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41597_019_0351_8/human_kidney_2020_10xsequencing_liao_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41597_019_0351_8/human_kidney_2020_10xsequencing_liao_001.py @@ -15,7 +15,7 @@ def __init__(self, **kwargs): self.download_url_data = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE131nnn/GSE131685/suppl/GSE131685_RAW.tar" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Liao" self.disease = "healthy" self.normalization = "raw" diff --git a/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py b/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py index 91b544321..4f528b133 100644 --- a/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py @@ -12,7 +12,7 @@ def __init__(self, **kwargs): self.download_url_data = "https://covid19.cog.sanger.ac.uk/voigt19.processed.h5ad" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v3" self.author = "Voigt" self.disease = "healthy" self.doi_journal = "10.1073/pnas.1914143116" diff --git a/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py b/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py index 0d03da3c2..4103cd1b6 100644 --- a/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py 
+++ b/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py @@ -21,7 +21,7 @@ def __init__(self, **kwargs): organ = self.sample_fn.split("_")[1].split(".")[0] - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' transcription profiling" self.author = "Wang" self.disease = "healthy" self.doi_journal = "10.1084/jem.20191130" diff --git a/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py b/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py index 900a25af3..0fa4ae2ee 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py @@ -18,7 +18,7 @@ def __init__(self, **kwargs): self.download_url_data = f"https://covid19.cog.sanger.ac.uk/{self.sample_fn}" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Lukassen" self.disease = "healthy" self.doi_journal = "10.15252/embj.20105114" diff --git a/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml b/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml index 4edce9df5..171dcbee0 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml @@ -13,7 +13,7 @@ dataset_wise: primary_data: year: 2020 dataset_or_observation_wise: - assay_sc: "10x technology" + assay_sc: "CITE-seq (cell surface protein profiling)" assay_sc_obs_key: assay_differentiation: assay_differentiation_obs_key: diff --git a/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py b/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py index f915f5b23..16b7bda98 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py +++ b/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py @@ -88,7 +88,7 @@ def __init__(self, **kwargs): self.normalization = "norm" self.organism = "mouse" self.organ = organ - self.assay_sc = "10x technology" if self.sample_fn.split("-")[3] == "droplet" else "Smart-seq2" + self.assay_sc = "10x 3' v2" if self.sample_fn.split("-")[3] == "droplet" else "Smart-seq2" self.year = 2019 self.sample_source = "primary_tissue" diff --git a/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py b/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py index 5164003b5..973d33b91 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py @@ -6,8 +6,14 @@ class Dataset(DatasetBase): + """ - TODO: add disease from status and diagnosis fields, healthy is "control" + TODO extra meta data in obs2 + + age: columns "Age" contains integer entries and Unknown + diseases: column "Diagnosis" contains entries NSIP, cHP, Control, IPF, ILD, Sarcoidosis + column Tobacco contains entries Y,N + ethnicity: column "Ethnicity" contains entries African_American, Caucasian, Hispanic, Unknown """ def __init__(self, **kwargs): @@ -17,7 +23,10 @@ def 
__init__(self, **kwargs): "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE135nnn/GSE135893/suppl/GSE135893%5Fgenes%2Etsv%2Egz", "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE135nnn/GSE135893/suppl/GSE135893%5Fbarcodes%2Etsv%2Egz" ] - self.download_url_meta = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE135nnn/GSE135893/suppl/GSE135893%5FIPF%5Fmetadata%2Ecsv%2Egz" + self.download_url_meta = [ + "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE135nnn/GSE135893/suppl/GSE135893%5FIPF%5Fmetadata%2Ecsv%2Egz", + "https://advances.sciencemag.org/highwire/filestream/234522/field_highwire_adjunct_files/2/aba1972_Table_S2.csv", + ] self.author = "Habermann" self.doi_journal = "10.1126/sciadv.aba1972" @@ -25,9 +34,11 @@ def __init__(self, **kwargs): self.normalization = "raw" self.organ = "lung parenchyma" self.organism = "human" - self.assay_sc = "10x technology" + self.assay_sc_obs_key = "Chemistry" self.year = 2020 self.sample_source = "primary_tissue" + self.sex_obs_key = "Gender" + self.tech_sample_obs_key = "Sample_Name" self.gene_id_symbols_var_key = "index" @@ -43,11 +54,17 @@ def load(data_dir, **kwargs): os.path.join(data_dir, "GSE135893_genes.tsv.gz"), os.path.join(data_dir, "GSE135893_barcodes.tsv.gz"), os.path.join(data_dir, "GSE135893_IPF_metadata.csv.gz"), + os.path.join(data_dir, "aba1972_Table_S2.csv"), ] adata = anndata.read_mtx(fn[0]).T adata.var = pd.read_csv(fn[1], index_col=0, header=None, names=["ids"]) adata.obs = pd.read_csv(fn[2], index_col=0, header=None, names=["barcodes"]) obs = pd.read_csv(fn[3], index_col=0) + obs2 = pd.read_csv(fn[4], index_col=0) + obs["Chemistry"] = [{"3_prime_V2": "10x 3' v2", "5_prime": "10x 5' v1"}[obs2.loc[x, "Chemistry"]] + for x in obs["orig.ident"].values] + obs["Gender"] = [{"F": "female", "M": "male", "Unknown": "unknown"}[obs2.loc[x, "Gender"]] + for x in obs["orig.ident"].values] adata = adata[obs.index.tolist(), :].copy() adata.obs = obs diff --git a/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py b/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py index 8233ad6b6..a4e2781d5 100644 --- a/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py @@ -7,6 +7,10 @@ class Dataset(DatasetBase): + """ + TODO transform field development to controlled field age + """ + def __init__(self, **kwargs): super().__init__(**kwargs) self.download_url_data = [ @@ -15,7 +19,7 @@ def __init__(self, **kwargs): ] self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Stewart" self.disease = "healthy" self.doi_journal = "10.1126/science.aat5031" @@ -40,6 +44,7 @@ def load(data_dir, **kwargs): ] adult = anndata.read(fn[0]) fetal = anndata.read(fn[1]) + # TODO this is is not a controlled field adult.obs["development"] = "adult" fetal.obs["development"] = "fetal" adata = adult.concatenate(fetal) diff --git a/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py b/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py index 6d41a10ae..d304fc729 100644 --- a/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py @@ -7,6 
+7,13 @@ class Dataset(DatasetBase): + """ + TODO add meta data + + .obs columns Age contains entries ['3m', '6m', '7w', '8w', '9w', '10m', '10w', '11w', '12w', '13w', '13y', + '14w', '15m', '16w', '17w', '24y', '30m', '35y'] + """ + def __init__(self, **kwargs): super().__init__(**kwargs) self.download_url_data = "https://covid19.cog.sanger.ac.uk/park20.processed.h5ad" @@ -16,10 +23,13 @@ def __init__(self, **kwargs): self.author = "Park" self.disease = "healthy" self.doi_journal = "10.1126/science.aay3224" + self.individual_obs_key = "donor" self.normalization = "norm" self.organ = "thymus" self.organism = "human" self.sample_source = "primary_tissue" + self.sex_obs_key = "Gender" + self.tech_sample_obs_key = "Sample" self.year = 2020 self.gene_id_symbols_var_key = "index" @@ -32,5 +42,6 @@ def load(data_dir, **kwargs): fn = os.path.join(data_dir, "park20.processed.h5ad") adata = anndata.read(fn) adata.X = np.expm1(adata.X) + adata.obs["Gender"] = [{"Male": "male", "Female": "female"}[x] for x in adata.obs["Gender"].values] return adata diff --git a/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py b/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py index 5762b926b..20c5d99fe 100644 --- a/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py @@ -33,7 +33,7 @@ def __init__(self, **kwargs): self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Madissoon" self.disease = "healthy" self.doi_journal = "10.1186/s13059-019-1906-x" diff --git a/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py b/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py index 3cecb29b4..2ad219b55 100644 --- a/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py +++ b/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py @@ -14,7 +14,7 @@ def __init__(self, **kwargs): self.download_url_data = "https://covid19.cog.sanger.ac.uk/lukowski19.processed.h5ad" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Lukowski" self.disease = "healthy" self.doi_journal = "10.15252/embj.2018100811" diff --git a/sfaira/data/dataloaders/loaders/dno_doi_10x_genomics/human_blood_2019_10xsequencing_10xgenomics_001.py b/sfaira/data/dataloaders/loaders/dno_doi_10x_genomics/human_blood_2019_10xsequencing_10xgenomics_001.py index 2358175df..b2c17a86b 100644 --- a/sfaira/data/dataloaders/loaders/dno_doi_10x_genomics/human_blood_2019_10xsequencing_10xgenomics_001.py +++ b/sfaira/data/dataloaders/loaders/dno_doi_10x_genomics/human_blood_2019_10xsequencing_10xgenomics_001.py @@ -15,7 +15,7 @@ def __init__(self, **kwargs): "http://cf.10xgenomics.com/samples/cell-exp/3.0.0/pbmc_10k_v3/pbmc_10k_v3_filtered_feature_bc_matrix.h5" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v3" self.author = "10x Genomics" self.disease = "healthy" self.doi_journal = "no_doi_10x_genomics" diff --git a/sfaira/data/dataloaders/loaders/dno_doi_regev/human_x_2018_10xsequencing_regev_001.py 
b/sfaira/data/dataloaders/loaders/dno_doi_regev/human_x_2018_10xsequencing_regev_001.py index b7cdb457f..af23a4cbc 100644 --- a/sfaira/data/dataloaders/loaders/dno_doi_regev/human_x_2018_10xsequencing_regev_001.py +++ b/sfaira/data/dataloaders/loaders/dno_doi_regev/human_x_2018_10xsequencing_regev_001.py @@ -6,13 +6,18 @@ class Dataset(DatasetBase): + """ + TODO data link is outdated. Maybe update to this + https://data.humancellatlas.org/explore/projects/cc95ff89-2e68-4a08-a234-480eca21ce79/project-matrices. + """ + def __init__(self, **kwargs): super().__init__(**kwargs) self.download_url_data = "https://data.humancellatlas.org/project-assets/project-matrices/" \ "cc95ff89-2e68-4a08-a234-480eca21ce79.homo_sapiens.loom" self.download_url_meta = None - self.assay_sc = "10x technology" + self.assay_sc = "10x 3' v2" self.author = "Regev" self.disease = "healthy" self.doi_journal = "no_doi_regev" diff --git a/sfaira/data/dataloaders/loaders/super_group.py b/sfaira/data/dataloaders/loaders/super_group.py index bbec57dfc..2ee26a1b1 100644 --- a/sfaira/data/dataloaders/loaders/super_group.py +++ b/sfaira/data/dataloaders/loaders/super_group.py @@ -2,7 +2,7 @@ import os from typing import List, Union from warnings import warn -from sfaira.data import DatasetSuperGroup, DatasetGroupDirectoryOriented +from sfaira.data.dataloaders.base.dataset_group import DatasetSuperGroup, DatasetGroupDirectoryOriented class DatasetSuperGroupLoaders(DatasetSuperGroup): diff --git a/sfaira/data/dataloaders/super_group.py b/sfaira/data/dataloaders/super_group.py index 316560f8a..9e7d3c3d2 100644 --- a/sfaira/data/dataloaders/super_group.py +++ b/sfaira/data/dataloaders/super_group.py @@ -7,7 +7,7 @@ from sfaira.data.dataloaders.loaders import DatasetSuperGroupLoaders from sfaira.data.dataloaders.databases import DatasetSuperGroupDatabases -from sfaira.data import DatasetSuperGroup +from sfaira.data.dataloaders.base.dataset_group import DatasetSuperGroup class Universe(DatasetSuperGroup): @@ -26,18 +26,23 @@ def __init__( :param meta_path: :param cache_path: """ + # TODO development flag excluding data bases from universes until this interface is finished. + exclude_databases = True dsgs = [ DatasetSuperGroupLoaders( data_path=data_path, meta_path=meta_path, cache_path=cache_path, ), - DatasetSuperGroupDatabases( - data_path=data_path, - meta_path=meta_path, - cache_path=cache_path, - ) ] + if not exclude_databases: + dsgs.append( + DatasetSuperGroupDatabases( + data_path=data_path, + meta_path=meta_path, + cache_path=cache_path, + ) + ) if sfairae is not None: dsgs.append(sfairae.data.dataloaders.loaders.DatasetSuperGroupLoaders( data_path=data_path, diff --git a/sfaira/data/interactive/loader.py b/sfaira/data/interactive/loader.py index 9df419c7a..ff4f459b1 100644 --- a/sfaira/data/interactive/loader.py +++ b/sfaira/data/interactive/loader.py @@ -39,7 +39,7 @@ def __init__( self.id = dataset_id self.author = "interactive_dataset" - self.doi = "interactive_dataset" + self.doi_journal = "interactive_dataset" self.download_url_data = "." self.download_url_meta = "." 
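The rename above is part of the patch-wide split of the former doi field into doi_journal and doi_preprint. A hypothetical loader-constructor fragment illustrating how the two fields are meant to be populated; the DOIs shown are placeholders, not tied to any specific loader:

# Inside a data loader constructor:
self.doi_journal = "10.1016/j.cell.2019.06.029"   # DOI of the peer-reviewed publication
self.doi_preprint = "10.1101/2020.03.13.991455"   # optional DOI of the preprint, if one exists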
diff --git a/sfaira/data/store/__init__.py b/sfaira/data/store/__init__.py new file mode 100644 index 000000000..48f12be63 --- /dev/null +++ b/sfaira/data/store/__init__.py @@ -0,0 +1,3 @@ +from sfaira.data.store.multi_store import load_store, DistributedStoreMultipleFeatureSpaceBase, \ + DistributedStoresH5ad, DistributedStoresDao +from sfaira.data.store.single_store import DistributedStoreSingleFeatureSpace diff --git a/sfaira/data/base/io_dao.py b/sfaira/data/store/io_dao.py similarity index 94% rename from sfaira/data/base/io_dao.py rename to sfaira/data/store/io_dao.py index 87402d1be..55f7fcab6 100644 --- a/sfaira/data/base/io_dao.py +++ b/sfaira/data/store/io_dao.py @@ -84,7 +84,7 @@ def write_dao(store: Union[str, Path], adata: anndata.AnnData, chunks: Union[boo def read_dao(store: Union[str, Path], use_dask: bool = True, columns: Union[None, List[str]] = None, - obs_separate: bool = False) -> \ + obs_separate: bool = False, x_separate: bool = False) -> \ Union[Tuple[anndata.AnnData, Union[dask.dataframe.DataFrame, pd.DataFrame]], anndata.AnnData]: """ Assembles an AnnData instance based on distributed access optimised ("dao") store of a dataset. @@ -118,16 +118,18 @@ def read_dao(store: Union[str, Path], use_dask: bool = True, columns: Union[None obs = pd.read_parquet(path_obs(store), columns=columns, engine="pyarrow") var = pd.read_parquet(path_var(store), engine="pyarrow") # Convert to categorical variables where possible to save memory: - for k, dtype in zip(list(obs.columns), obs.dtypes): - if dtype == "object": - obs[k] = obs[k].astype(dtype="category") + # for k, dtype in zip(list(obs.columns), obs.dtypes): + # if dtype == "object": + # obs[k] = obs[k].astype(dtype="category") d = {"var": var, "uns": uns} # Assemble AnnData without obs to save memory: adata = anndata.AnnData(**d, shape=x.shape) # Need to add these attributes after initialisation so that they are not evaluated: - adata.X = x + if not x_separate: + adata.X = x + if not obs_separate: + adata.obs = obs if obs_separate: return adata, obs - else: - adata.obs = obs - return adata + if x_separate: + return adata, x diff --git a/sfaira/data/store/multi_store.py b/sfaira/data/store/multi_store.py new file mode 100644 index 000000000..dc4e51521 --- /dev/null +++ b/sfaira/data/store/multi_store.py @@ -0,0 +1,338 @@ +import abc +import anndata +import numpy as np +import os +import pickle +from typing import Dict, List, Tuple, Union + +from sfaira.consts import AdataIdsSfaira +from sfaira.data.store.single_store import DistributedStoreSingleFeatureSpace, \ + DistributedStoreDao, DistributedStoreH5ad +from sfaira.data.store.io_dao import read_dao +from sfaira.versions.genomes.genomes import GenomeContainer + + +class DistributedStoreMultipleFeatureSpaceBase(abc.ABC): + + """ + Umbrella class for a dictionary over multiple instances DistributedStoreSingleFeatureSpace. + + Allows for operations on data sets that are defined in different feature spaces. + """ + + _adata_ids_sfaira: AdataIdsSfaira + _stores: Dict[str, DistributedStoreSingleFeatureSpace] + + def __init__(self, stores: Dict[str, DistributedStoreSingleFeatureSpace]): + self._stores = stores + + @property + def stores(self) -> Dict[str, DistributedStoreSingleFeatureSpace]: + """ + Only expose stores that contain observations. 
+ """ + return dict([(k, v) for k, v in self._stores.items() if v.n_obs > 0]) + + @stores.setter + def stores(self, x: Dict[str, DistributedStoreSingleFeatureSpace]): + raise NotImplementedError("cannot set this attribute, it s defined in constructor") + + @property + def genome_containers(self) -> Dict[str, Union[GenomeContainer, None]]: + return dict([(k, v.genome_container) for k, v in self._stores.items()]) + + @genome_containers.setter + def genome_containers(self, x: Union[GenomeContainer, Dict[str, GenomeContainer]]): + if isinstance(x, GenomeContainer): + # Transform into dictionary first. + organisms = [k for k, v in self.stores.items()] + if isinstance(organisms, list) and len(organisms) == 0: + raise Warning("found empty organism lists in genome_container.setter") + if len(organisms) > 1: + raise ValueError(f"Gave a single GenomeContainer for a store instance that has mulitiple organism: " + f"{organisms}, either further subset the store or give a dictionary of " + f"GenomeContainers") + else: + x = {organisms[0]: x} + for k, v in x.items(): + self.stores[k].genome_container = v + + @property + def indices(self) -> Dict[str, np.ndarray]: + """ + Dictionary of indices of selected observations contained in all stores. + """ + return dict([(kk, vv) for k, v in self.stores.items() for kk, vv in v.indices.items()]) + + @property + def adata_by_key(self) -> Dict[str, anndata.AnnData]: + """ + Dictionary of all anndata instances for each selected data set in store, sub-setted by selected cells, for each + stores. + """ + return dict([(kk, vv) for k, v in self.stores.items() for kk, vv in v.adata_by_key.items()]) + + @property + def data_by_key(self): + """ + Data matrix for each selected data set in store, sub-setted by selected cells. + """ + return dict([(kk, vv) for k, v in self.stores.items() for kk, vv in v.data_by_key.items()]) + + @property + def var_names(self) -> Dict[str, List[str]]: + """ + Dictionary of variable names by store. + """ + return dict([(k, v.var_names) for k, v in self.stores.items()]) + + @property + def n_vars(self) -> Dict[str, int]: + """ + Dictionary of number of features by store. + """ + return dict([(k, v.n_vars) for k, v in self.stores.items()]) + + @property + def n_obs(self) -> Dict[str, int]: + """ + Dictionary of number of observations by store. + """ + return dict([(k, v.n_obs) for k, v in self.stores.items()]) + + @property + def obs(self): + """ + Dictionary of concatenated .obs tables by store, only including non-empty stores. + """ + return dict([(k, v.obs) for k, v in self.stores.items()]) + + @property + def X(self): + """ + Dictionary of concatenated data matrices by store, only including non-empty stores. + """ + return dict([(k, v.X) for k, v in self.stores.items()]) + + @property + def shape(self) -> Dict[str, Tuple[int, int]]: + """ + Dictionary of full selected data matrix shape by store. + """ + return dict([(k, v.shape) for k, v in self.stores.items()]) + + def subset(self, attr_key, values: Union[str, List[str], None] = None, + excluded_values: Union[str, List[str], None] = None, verbose: int = 1): + """ + Subset list of adata objects based on cell-wise properties. + + Subsetting is done based on index vectors, the objects remain untouched. + + :param attr_key: Property to subset by. 
Options: + + - "assay_differentiation" points to self.assay_differentiation_obs_key + - "assay_sc" points to self.assay_sc_obs_key + - "assay_type_differentiation" points to self.assay_type_differentiation_obs_key + - "cell_line" points to self.cell_line + - "cellontology_class" points to self.cellontology_class_obs_key + - "developmental_stage" points to self.developmental_stage_obs_key + - "ethnicity" points to self.ethnicity_obs_key + - "organ" points to self.organ_obs_key + - "organism" points to self.organism_obs_key + - "sample_source" points to self.sample_source_obs_key + - "sex" points to self.sex_obs_key + - "state_exact" points to self.state_exact_obs_key + :param values: Classes to overlap to. Supply either values or excluded_values. + :param excluded_values: Classes to exclude from match list. Supply either values or excluded_values. + """ + for k in self.stores.keys(): + self.stores[k].subset(attr_key=attr_key, values=values, excluded_values=excluded_values, verbose=0) + if self.n_obs == 0 and verbose > 0: + print("WARNING: multi store is now empty.") + + def write_config(self, fn: Union[str, os.PathLike]): + """ + Writes a config file that describes the current data sub-setting. + + This config file can be loaded later to recreate a sub-setting. + This config file contains observation-wise subsetting information. + + :param fn: Output file without file type extension. + """ + indices = {} + for v in self.stores.values(): + indices.update(v.indices) + with open(fn + '.pickle', 'wb') as f: + pickle.dump(indices, f) + + def load_config(self, fn: Union[str, os.PathLike]): + """ + Load a config file and recreates a data sub-setting. + This config file contains observation-wise subsetting information. + + :param fn: Output file without file type extension. + """ + with open(fn, 'rb') as f: + indices = pickle.load(f) + # Distribute indices to corresponding stores by matched keys. + keys_not_found = list(indices.keys()) + for k, v in self.stores.items(): + indices_k = {} + for i, (kk, vv) in enumerate(indices.items()): + if kk in v.adata_by_key.keys(): + indices_k[kk] = vv + del keys_not_found[i] + self.stores[k].indices = indices_k + # Make sure all declared data were assigned to stores: + if len(keys_not_found) > 0: + raise ValueError(f"did not find object(s) with name(s) in store: {keys_not_found}") + + def generator( + self, + idx: Union[Dict[str, Union[np.ndarray, None]], None] = None, + intercalated: bool = True, + **kwargs + ) -> Tuple[iter, int]: + """ + Emission of batches from unbiased generators of all stores. + + See also DistributedStore*.generator(). + + :param idx: + :param intercalated: Whether to do sequential or intercalated emission. + :param kwargs: See parameters of DistributedStore*.generator(). + """ + if idx is None: + idx = dict([(k, None) for k in self.stores.keys()]) + for k in self.stores.keys(): + assert k in idx.keys(), (idx.keys(), self.stores.keys()) + generators = [ + v.generator(idx=idx[k], **kwargs) + for k, v in self.stores.items() + ] + generator_fns = [x[0]() for x in generators] + generator_len = [x[1] for x in generators] + + if intercalated: + # Define relative drawing frequencies from iterators for intercalation. + ratio = np.asarray(np.round(np.max(generator_len) / np.asarray(generator_len), 0), dtype="int64") + + def generator(): + # Document which generators are still yielding batches: + yielding = np.ones((ratio.shape[0],)) == 1. + while np.any(yielding): + # Loop over one iterator length adjusted cycle of emissions. 
+ for i, (g, n) in enumerate(zip(generator_fns, ratio)): + for _ in range(n): + try: + x = next(g) + yield x + except StopIteration: + yielding[i] = False + else: + def generator(): + for g in generator_fns: + for x in g(): + yield x + + return generator, int(np.sum(generator_len)) + + +class DistributedStoresDao(DistributedStoreMultipleFeatureSpaceBase): + + _dataset_weights: Union[None, Dict[str, float]] + + def __init__(self, cache_path: Union[str, os.PathLike], columns: Union[None, List[str]] = None): + """ + + :param cache_path: Store directory. + :param columns: Which columns to read into the obs copy in the output, see pandas.read_parquet(). + """ + # Collect all data loaders from files in directory: + self._adata_ids_sfaira = AdataIdsSfaira() + adata_by_key = {} + x_by_key = {} + indices = {} + for f in np.sort(os.listdir(cache_path)): + adata = None + x = None + trial_path = os.path.join(cache_path, f) + if os.path.isdir(trial_path): + # zarr-backed anndata are saved as directories with the elements of the array group as further sub + # directories, e.g. a directory called "X", and a file ".zgroup" which identifies the zarr group. + adata, x = read_dao(trial_path, use_dask=True, columns=columns, obs_separate=False, x_separate=True) + if adata is not None: + organism = adata.uns[self._adata_ids_sfaira.organism] + if organism not in adata_by_key.keys(): + adata_by_key[organism] = {} + x_by_key[organism] = {} + indices[organism] = {} + adata_by_key[organism][adata.uns["id"]] = adata + x_by_key[organism][adata.uns["id"]] = x + indices[organism][adata.uns["id"]] = np.arange(0, adata.n_obs) + self._x_by_key = x_by_key + stores = dict([ + (k, DistributedStoreDao(adata_by_key=adata_by_key[k], x_by_key=x_by_key[k], indices=indices[k], + obs_by_key=None)) + for k in adata_by_key.keys() + ]) + super(DistributedStoresDao, self).__init__(stores=stores) + + +class DistributedStoresH5ad(DistributedStoreMultipleFeatureSpaceBase): + + def __init__(self, cache_path: Union[str, os.PathLike], in_memory: bool = False): + # Collect all data loaders from files in directory: + self._adata_ids_sfaira = AdataIdsSfaira() + adata_by_key = {} + indices = {} + for f in np.sort(os.listdir(cache_path)): + adata = None + trial_path = os.path.join(cache_path, f) + if os.path.isfile(trial_path): + # Narrow down to supported file types: + if f.split(".")[-1] == "h5ad": + try: + adata = anndata.read_h5ad( + filename=trial_path, + backed="r" if in_memory else None, + ) + except OSError as e: + adata = None + print(f"WARNING: for data set {f}: {e}") + if adata is not None: + organism = adata.uns[self._adata_ids_sfaira.organism] + if organism not in adata_by_key.keys(): + adata_by_key[organism] = {} + indices[organism] = {} + adata_by_key[organism][adata.uns["id"]] = adata + indices[organism][adata.uns["id"]] = np.arange(0, adata.n_obs) + stores = dict([ + (k, DistributedStoreH5ad(adata_by_key=adata_by_key[k], indices=indices[k], in_memory=in_memory)) + for k in adata_by_key.keys() + ]) + super(DistributedStoresH5ad, self).__init__(stores=stores) + + +def load_store(cache_path: Union[str, os.PathLike], store_format: str = "dao", + columns: Union[None, List[str]] = None) -> Union[DistributedStoresH5ad, DistributedStoresDao]: + """ + Instantiates a distributed store class. + + :param cache_path: Store directory. + :param store_format: Format of store {"h5ad", "dao"}. + + - "h5ad": Returns instance of DistributedStoreH5ad. + - "dao": Returns instance of DistributedStoreDoa (distributed access optimized). 
+ :param columns: Which columns to read into the obs copy in the output, see pandas.read_parquet(). + Only relevant if store_format is "dao". + :return: Instances of a distributed store class. + """ + if store_format == "anndata": + return DistributedStoresH5ad(cache_path=cache_path, in_memory=True) + elif store_format == "dao": + return DistributedStoresDao(cache_path=cache_path, columns=columns) + elif store_format == "h5ad": + return DistributedStoresH5ad(cache_path=cache_path, in_memory=False) + else: + raise ValueError(f"Did not recognize store_format {store_format}.") diff --git a/sfaira/data/base/distributed_store.py b/sfaira/data/store/single_store.py similarity index 53% rename from sfaira/data/base/distributed_store.py rename to sfaira/data/store/single_store.py index 310124ef6..c9e8e66fe 100644 --- a/sfaira/data/base/distributed_store.py +++ b/sfaira/data/store/single_store.py @@ -2,18 +2,19 @@ import anndata import dask.array import dask.dataframe +import h5py import numpy as np import os import pandas as pd import pickle import scipy.sparse import sys +import time from typing import Dict, List, Tuple, Union from sfaira.consts import AdataIdsSfaira, OCS -from sfaira.data.base.dataset import is_child, UNS_STRING_META_IN_OBS -from sfaira.data.base.io_dao import read_dao -from sfaira.versions.genomes import GenomeContainer +from sfaira.data.dataloaders.base.utils import is_child, UNS_STRING_META_IN_OBS +from sfaira.versions.genomes.genomes import GenomeContainer """ Distributed stores are array-like classes that sit on groups of on disk representations of anndata instances files. @@ -21,27 +22,30 @@ In particular, if .X is saved as zarr array, one can use lazy dask arrays to operate across sets of count matrices, heavily reducing the complexity of the code required here and often increasing access speed. +This instances sit on groups of data objects that have the same feature space, e.g. are from the same organism. +You can operate on multiple organisms at the same time by using an umbrella class over a set of such instances by using +DistributedStoreMultipleFeatureSpaceBase. + DistributedStoreBase is base class for any file format on disk. -DistributedStoreZarr is adapted to classes that store an anndata instance as a zarr group. -DistributedStoreH5ad is adapted to classes that store an anndata instance as a h5ad file. +DistributedStoreDao wraps an on-disk representation of anndata instance in the sfaira "dao" format. +DistributedStoreH5ad wraps an on-disk representation of anndata instances as a h5ad file. +DistributedStoreAnndata wraps in-memory anndata instance. Note that in all cases, you can use standard anndata reading functions to load a single object into memory. """ -def access_helper(adata, s, e, j, return_dense, obs_keys) -> tuple: - x = adata.X[s:e, :] - # Do dense conversion now so that col-wise indexing is not slow, often, dense conversion - # would be done later anyway. 
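# Illustrative end-to-end usage of the store API above; the cache path is a placeholder and the
# call pattern mirrors the benchmark script added further below in this patch
# (sfaira.data.load_store -> .subset -> .genome_containers -> .generator).
import sfaira

stores = sfaira.data.load_store(cache_path="/path/to/store_dao", store_format="dao")
# Note: besides "h5ad" and "dao", the code above also accepts store_format="anndata",
# which loads the h5ad files fully into memory.
stores.subset(attr_key="organism", values="human")  # lazy: only index vectors are trimmed
# A single GenomeContainer is accepted by the setter because only one organism remains;
# the feature space is restricted to a few genes here, as done in the benchmark script:
gc = sfaira.versions.genomes.genomes.GenomeContainer(assembly="Homo_sapiens.GRCh38.102")
gc.subset(symbols=["VTA1", "MLXIPL", "BAZ1B"])
stores.genome_containers = gc
stores.write_config(fn="/path/to/subset_config")  # persists indices as /path/to/subset_config.pickle
# Batched access across stores; returns a generator function and the total number of batches:
gen_fn, n_batches = stores.generator(
    idx={k: v.idx for k, v in stores.stores.items()},  # all currently selected observations
    batch_size=64,
    obs_keys=[],
    intercalated=True,
)
for x, obs in gen_fn():
    pass  # x: count matrix batch, obs: matching .obs slice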
- if return_dense and isinstance(x, scipy.sparse.spmatrix): - x = x.todense() - if j is not None: - x = x[:, j] - obs = adata.obs[obs_keys].iloc[s:e, :] - return x, obs +def _process_batch_size(x: int, idx: np.ndarray) -> int: + if x > len(idx): + batch_size_new = len(idx) + print(f"WARNING: reducing retrieval batch size according to data availability in store " + f"from {x} to {batch_size_new}") + x = batch_size_new + return x + +class DistributedStoreSingleFeatureSpace: -class DistributedStoreBase(abc.ABC): """ Data set group class tailored to data access requirements common in high-performance computing (HPC). @@ -50,34 +54,47 @@ class DistributedStoreBase(abc.ABC): .adata_by_key is a dictionary (by id) of backed anndata instances that point to individual h5ads. This dictionary is intialised with all h5ads in the store. - As the store is subsetted, key-value pairs are deleted from this dictionary. + As the store is sub-setted, key-value pairs are deleted from this dictionary. .indices have keys that correspond to keys in .adata_by_key and contain index vectors of observations in the anndata instances in .adata_by_key which are still kept. These index vectors are a form of lazy slicing that does not require data set loading or re-writing. - As the store is subsetted, key-value pairs are deleted from this dictionary if no observations from a given key - match the subsetting. + As the store is sub-setted, key-value pairs are deleted from this dictionary if no observations from a given key + match the sub-setting. If a subset of observations from a key matches the subsetting operation, the index set in the corresponding value is reduced. - All data retrievel operations work on .indices: Generators run over these indices when retrieving observations for + All data retrieval operations work on .indices: Generators run over these indices when retrieving observations for example. """ _adata_by_key: Dict[str, anndata.AnnData] _indices: Dict[str, np.ndarray] _obs_by_key: Union[None, Dict[str, dask.dataframe.DataFrame]] + data_source: str def __init__(self, adata_by_key: Dict[str, anndata.AnnData], indices: Dict[str, np.ndarray], - obs_by_key: Union[None, Dict[str, dask.dataframe.DataFrame]] = None): + obs_by_key: Union[None, Dict[str, dask.dataframe.DataFrame]] = None, data_source: str = "X"): self.adata_by_key = adata_by_key self.indices = indices self.obs_by_key = obs_by_key self.ontology_container = OCS self._genome_container = None self._adata_ids_sfaira = AdataIdsSfaira() + self.data_source = data_source self._celltype_universe = None + @property + def idx(self) -> np.ndarray: + """ + Global indices. + """ + idx_global = np.arange(0, np.sum([len(v) for v in self.indices.values()])) + return idx_global + def _validate_idx(self, idx: Union[np.ndarray, list]) -> np.ndarray: + """ + Validate global index vector. + """ assert np.max(idx) < self.n_obs, f"maximum of supplied index vector {np.max(idx)} exceeds number of modelled " \ f"observations {self.n_obs}" assert len(idx) == len(np.unique(idx)), f"there were {len(idx) - len(np.unique(idx))} repeated indices in idx" @@ -90,49 +107,73 @@ def _validate_idx(self, idx: Union[np.ndarray, list]) -> np.ndarray: idx = np.asarray(idx) return idx + @property + def organisms_by_key(self) -> Dict[str, str]: + """ + Data set-wise organism label as dictionary of data set keys. 
+ """ + ks = self.indices.keys() + organisms = [self._adata_by_key[k].uns[self._adata_ids_sfaira.organism] for k in ks] + # Flatten list, assumes that each data set maps to one organism: + organisms = [x[0] if (isinstance(x, list) or isinstance(x, tuple)) else x for x in organisms] + return dict(list(zip(ks, organisms))) + + @property + def organism(self): + """ + Organism of store. + """ + organisms = np.sort(np.unique(list(self.organisms_by_key.values()))) + assert len(organisms) == 1, organisms + return organisms[0] + def _validate_feature_space_homogeneity(self) -> List[str]: """ Assert that the data sets which were kept have the same feature names. + + :return: List of feature names in shared feature space or dictionary of list of features. """ - var_names = self._adata_by_key[list(self.indices.keys())[0]].var_names.tolist() - for k, v in self.indices.items(): + reference_k = list(self._adata_by_key.keys())[0] + var_names = self._adata_by_key[reference_k].var_names.tolist() + for k in list(self._adata_by_key.keys()): assert len(var_names) == len(self._adata_by_key[k].var_names), \ - f"number of features in store differed in object {k} compared to {list(self._adata_by_key.keys())[0]}" + f"number of features in store differed in object {k} compared to {reference_k}" assert np.all(var_names == self._adata_by_key[k].var_names), \ - f"var_names in store were not matched in object {k} compared to {list(self._adata_by_key.keys())[0]}" + f"var_names in store were not matched in object {k} compared to {reference_k}" return var_names - def _generator_helper( - self, - idx: Union[np.ndarray, None] = None, - ) -> Tuple[Union[np.ndarray, None], Union[np.ndarray, None]]: - # Make sure that features are ordered in the same way in each object so that generator yields consistent cell - # vectors. - _ = self._validate_feature_space_homogeneity() - var_names_store = self.adata_by_key[list(self.indices.keys())[0]].var_names.tolist() - # Use feature space sub-selection based on assembly if provided, will use full feature space otherwise. - if self.genome_container is not None: - var_names_target = self.genome_container.ensembl - var_idx = np.sort([var_names_store.index(x) for x in var_names_target]) - # Check if index vector is just full ordered list of indices, in this case, sub-setting is unnecessary. - if len(var_idx) == len(var_names_store) and np.all(var_idx == np.arange(0, len(var_names_store))): - var_idx = None - else: - var_idx = None - if idx is not None: - idx = self._validate_idx(idx) - return idx, var_idx - @property def adata_by_key(self) -> Dict[str, anndata.AnnData]: + """ + Anndata instance for each selected data set in store, sub-setted by selected cells. + """ return self._adata_by_key @adata_by_key.setter def adata_by_key(self, x: Dict[str, anndata.AnnData]): self._adata_by_key = x + @property + def data_by_key(self): + """ + Data matrix for each selected data set in store, sub-setted by selected cells. + """ + return dict([(k, v.X) for k, v in self.adata_by_key.items()]) + + @property + def adata_memory_footprint(self) -> Dict[str, float]: + """ + Memory foot-print of data set k in MB. + """ + return dict([(k, sys.getsizeof(v) / np.power(1024, 2)) for k, v in self.adata_by_key.items()]) + @property def indices(self) -> Dict[str, np.ndarray]: + """ + Indices of observations that are currently exposed in adata of this instance. + + This depends on previous subsetting. 
+ """ return self._indices @indices.setter @@ -176,7 +217,7 @@ def genome_container(self) -> Union[GenomeContainer, None]: return self._genome_container @genome_container.setter - def genome_container(self, x: GenomeContainer): + def genome_container(self, x: Union[GenomeContainer]): var_names = self._validate_feature_space_homogeneity() # Validate genome container choice: # Make sure that all var names defined in genome container are also contained in loaded data sets. @@ -184,6 +225,16 @@ def genome_container(self, x: GenomeContainer): "did not find variable names from genome container in store" self._genome_container = x + @property + def dataset_weights(self): + return self._dataset_weights + + @dataset_weights.setter + def dataset_weights(self, x: Dict[str, float]): + assert np.all([k in self.adata_by_key.keys() for k in x.keys()]), "did not recognize some keys" + assert np.all([k in x.keys() for k in self.indices.keys()]), "some data sets in index were omitted" + self._dataset_weights = x + def get_subset_idx(self, attr_key, values: Union[str, List[str], None], excluded_values: Union[str, List[str], None]) -> dict: """ @@ -216,7 +267,8 @@ def get_idx(adata, obs, k, v, xv, dataset): # Use cell-wise annotation if data set-wide maps are ambiguous: # This can happen if the different cell-wise annotations are summarised as a union in .uns. if getattr(self._adata_ids_sfaira, k) in adata.uns.keys() and \ - adata.uns[getattr(self._adata_ids_sfaira, k)] != UNS_STRING_META_IN_OBS: + adata.uns[getattr(self._adata_ids_sfaira, k)] != UNS_STRING_META_IN_OBS and \ + getattr(self._adata_ids_sfaira, k) not in obs.columns: values_found = adata.uns[getattr(self._adata_ids_sfaira, k)] if isinstance(values_found, np.ndarray): values_found = values_found.tolist() @@ -256,9 +308,6 @@ def get_idx(adata, obs, k, v, xv, dataset): ]) ] idx = np.where([x in values_found_unique_matched for x in values_found])[0] - if len(idx) > 0: - # TODO keep this logging for now to catch undesired behaviour resulting from loaded edges in ontologies. - print(f"matched keys {str(values_found_unique_matched)} in data set {dataset}") return idx indices = {} @@ -279,7 +328,7 @@ def get_idx(adata, obs, k, v, xv, dataset): return indices def subset(self, attr_key, values: Union[str, List[str], None] = None, - excluded_values: Union[str, List[str], None] = None): + excluded_values: Union[str, List[str], None] = None, verbose: int = 1): """ Subset list of adata objects based on cell-wise properties. @@ -303,7 +352,7 @@ def subset(self, attr_key, values: Union[str, List[str], None] = None, :param excluded_values: Classes to exclude from match list. Supply either values or excluded_values. """ self.indices = self.get_subset_idx(attr_key=attr_key, values=values, excluded_values=excluded_values) - if self.n_obs == 0: + if self.n_obs == 0 and verbose > 0: print("WARNING: store is now empty.") def write_config(self, fn: Union[str, os.PathLike]): @@ -333,7 +382,10 @@ def load_config(self, fn: Union[str, os.PathLike]): raise ValueError(f"did not find object with name {x} in currently loaded universe") @property - def var_names(self): + def var_names(self) -> List[str]: + """ + Feature names of selected genes by organism in store. + """ var_names = self._validate_feature_space_homogeneity() # Use feature space sub-selection based on assembly if provided, will use full feature space otherwise. 
if self.genome_container is None: @@ -342,7 +394,10 @@ def var_names(self): return self.genome_container.ensembl @property - def n_vars(self): + def n_vars(self) -> int: + """ + Number of selected features per organism in store + """ var_names = self._validate_feature_space_homogeneity() # Use feature space sub-selection based on assembly if provided, will use full feature space otherwise. if self.genome_container is None: @@ -351,14 +406,47 @@ def n_vars(self): return self.genome_container.n_var @property - def n_obs(self): + def n_obs(self) -> int: + """ + Number of observations selected in store. + """ return np.sum([len(v) for v in self.indices.values()]) @property - def shape(self): - return [self.n_obs, self.n_vars] + def shape(self) -> Tuple[int, int]: + return self.n_obs, self.n_vars @abc.abstractmethod + def _generator( + self, + idx_gen: iter, + var_idx: Union[np.ndarray, None], + obs_keys: List[str], + ) -> iter: + pass + + def _generator_helper( + self, + idx: Union[np.ndarray, None], + batch_size: int, + ) -> Tuple[Union[np.ndarray, None], Union[np.ndarray, None], int]: + # Make sure that features are ordered in the same way in each object so that generator yields consistent cell + # vectors. + var_names = self._validate_feature_space_homogeneity() + # Use feature space sub-selection based on assembly if provided, will use full feature space otherwise. + if self.genome_container is not None: + var_names_target = self.genome_container.ensembl + var_idx = np.sort([var_names.index(x) for x in var_names_target]) + # Check if index vector is just full ordered list of indices, in this case, sub-setting is unnecessary. + if len(var_idx) == len(var_names) and np.all(var_idx == np.arange(0, len(var_names))): + var_idx = None + else: + var_idx = None + if idx is not None: + idx = self._validate_idx(idx) + batch_size = _process_batch_size(x=batch_size, idx=idx) + return idx, var_idx, batch_size + def generator( self, idx: Union[np.ndarray, None] = None, @@ -366,12 +454,137 @@ def generator( obs_keys: List[str] = [], return_dense: bool = True, randomized_batch_access: bool = False, + random_access: bool = False, + **kwargs ) -> iter: - pass + """ + Yields an unbiased generator over observations in the contained data sets. + + :param idx: Global idx to query from store. These is an array with indicies corresponding to a contiuous index + along all observations in self.adata_by_key, ordered along a hypothetical concatenation along the keys of + self.adata_by_key. + :param batch_size: Number of observations read from disk in each batched access (generator invocation). + :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available + in self.adata_by_key. + :param return_dense: Whether to force return count data .X as dense batches. This allows more efficient feature + indexing if the store is sparse (column indexing on csr matrices is slow). + :param randomized_batch_access: Whether to randomize batches during reading (in generator). Lifts necessity of + using a shuffle buffer on generator, however, batch composition stays unchanged over epochs unless there + is overhangs in retrieval_batch_size in the raw data files, which often happens and results in modest + changes in batch composition. + Do not use randomized_batch_access and random_access. + :param random_access: Whether to fully shuffle observations before batched access takes place. May + slow down access compared randomized_batch_access and to no randomization. 
+ Do not use randomized_batch_access and random_access. + :return: Generator function which yields batch_size at every invocation. + The generator returns a tuple of (.X, .obs). + """ + idx, var_idx, batch_size = self._generator_helper(idx=idx, batch_size=batch_size) + if randomized_batch_access and random_access: + raise ValueError("Do not use randomized_batch_access and random_access.") + n_obs = len(idx) + remainder = n_obs % batch_size + n_batches = int(n_obs // batch_size + int(remainder > 0)) + + def idx_gen(): + """ + Yields index objects for one epoch of all data. + + These index objects are used by generators that have access to the data objects to build data batches. + + :returns: Tuple of: + - Ordering of observations in epoch. + - Batch start and end indices for batch based on ordering defined in first output. + """ + batch_starts_ends = [ + (int(x * batch_size), int(np.minimum((x * batch_size) + batch_size, n_obs))) + for x in np.arange(0, n_batches) + ] + batch_range = np.arange(0, len(batch_starts_ends)) + if randomized_batch_access: + np.random.shuffle(batch_range) + batch_starts_ends = [batch_starts_ends[i] for i in batch_range] + obs_idx = idx.copy() + if random_access: + np.random.shuffle(obs_idx) + yield obs_idx, batch_starts_ends + + return self._generator(idx_gen=idx_gen(), var_idx=var_idx, obs_keys=obs_keys), n_batches + + def generator_balanced( + self, + idx: Union[np.ndarray, None] = None, + balance_obs: Union[str, None] = None, + balance_damping: float = 0., + batch_size: int = 1, + obs_keys: List[str] = [], + **kwargs + ) -> iter: + """ + Yields a data set balanced generator. + + Yields one random batch per dataset. Assumes that data sets are annotated in .obs. + Uses self.dataset_weights if this are given to sample data sets with different frequencies. + Can additionally also balance across one meta data annotation within each data set. + + Assume you have a data set with two classes (A=80, B=20 cells) in a column named "cellontology_class". + The single batch for this data set produced by this generator in each epoch contains N cells. + If balance_obs is False, these N cells are the result of a draw without replacement from all 100 cells in this + dataset in which each cell receives the same weight / success probability of 1.0. + If balance_obs is True, these N cells are the result of a draw without replacement from all 100 cells in this + data set with individual success probabilities such that classes are balanced: 0.2 for A and 0.8 for B. + + :param idx: Global idx to query from store. These is an array with indicies corresponding to a contiuous index + along all observations in self.adata_by_key, ordered along a hypothetical concatenation along the keys of + self.adata_by_key. + :param balance_obs: .obs column key to balance samples from each data set over. + Note that each data set must contain this column in its .obs table. + :param balance_damping: Damping to apply to class weighting induced by balance_obs. The class-wise + wise sampling probabilities become `max(balance_damping, (1. - frequency))` + :param batch_size: Number of observations read from disk in each batched access (generator invocation). + :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available + in self.adata_by_key. + :return: Generator function which yields batch_size at every invocation. + The generator returns a tuple of (.X, .obs). 
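# Minimal sketch of batched access on a single-feature-space store; the cache path is a
# placeholder and the store is assumed to contain human data. randomized_batch_access shuffles
# the order of fixed batches between epochs, random_access shuffles observations before
# batching, and the two flags are mutually exclusive.
import sfaira

store = sfaira.data.load_store(cache_path="/path/to/store_dao", store_format="dao").stores["human"]
gen_fn, n_batches = store.generator(
    idx=store.idx,                 # all currently selected observations
    batch_size=128,
    obs_keys=[],
    randomized_batch_access=True,
    random_access=False,
)
for x, obs in gen_fn():
    pass  # x: batch of counts, obs: matching pandas.DataFrame

# Worked illustration of the damped class weights used by generator_balanced (numbers are
# hypothetical): with class frequencies A=0.8 and B=0.2 and balance_damping=0.3, the per-cell
# coefficients are max(0.3, 1 - 0.8) = 0.3 for A-cells and max(0.3, 1 - 0.2) = 0.8 for B-cells,
# i.e. rare classes are up-weighted while the down-weighting of frequent classes is floored at
# balance_damping. Data set frequencies can additionally be tilted via store.dataset_weights.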
+ """ + idx, var_idx, batch_size = self._generator_helper(idx=idx, batch_size=batch_size) + + def idx_gen(): + batch_starts_ends = [] + idx_proc = [] + val_dataset = self.obs[self._adata_ids_sfaira.dataset].values + datasets = np.unique(val_dataset) + if self.dataset_weights is not None: + weights = np.array([self.dataset_weights[x] for x in datasets]) + p = weights / np.sum(weights) + datasets = np.random.choice(a=datasets, replace=True, size=len(datasets), p=p) + if balance_obs is not None: + val_meta = self.obs[balance_obs].values + for x in datasets: + idx_x = np.where(val_dataset == x)[0] + n_obs = len(idx_x) + batch_size_o = int(np.minimum(batch_size, n_obs)) + batch_starts_ends.append(np.array([(0, batch_size_o), ])) + if balance_obs is None: + p = np.ones_like(idx_x) / len(idx_x) + else: + if balance_obs not in self.obs.columns: + raise ValueError(f"did not find column {balance_obs} in {self.organism}") + val_meta_x = val_meta[idx_x] + class_freq = dict([(y, np.mean(val_meta_x == y)) for y in np.unique(val_meta_x)]) + class_freq_x_by_obs = np.array([class_freq[y] for y in val_meta_x]) + damped_freq_coefficient = np.maximum(balance_damping, (1. - class_freq_x_by_obs)) + p = np.ones_like(idx_x) / len(idx_x) * damped_freq_coefficient + idx_x_sample = np.random.choice(a=idx_x, replace=False, size=batch_size_o, p=p) + idx_proc.append(idx_x_sample) + idx_proc = np.asarray(idx_proc) + yield idx_proc, batch_starts_ends + + return self._generator(idx_gen=idx_gen(), var_idx=var_idx, obs_keys=obs_keys) @property @abc.abstractmethod - def X(self) -> Union[dask.array.Array, scipy.sparse.csr_matrix]: + def X(self): pass @property @@ -379,37 +592,15 @@ def X(self) -> Union[dask.array.Array, scipy.sparse.csr_matrix]: def obs(self) -> Union[pd.DataFrame]: pass - @abc.abstractmethod - def n_counts(self, idx: Union[np.ndarray, list, None] = None) -> np.ndarray: - pass +class DistributedStoreH5ad(DistributedStoreSingleFeatureSpace): -class DistributedStoreH5ad(DistributedStoreBase): + in_memory: bool - def __init__(self, cache_path: Union[str, os.PathLike]): - # Collect all data loaders from files in directory: - adata_by_key = {} - indices = {} - for f in os.listdir(cache_path): - adata = None - trial_path = os.path.join(cache_path, f) - if os.path.isfile(trial_path): - # Narrow down to supported file types: - if f.split(".")[-1] == "h5ad": - print(f"Discovered {f} as .h5ad file.") - try: - adata = anndata.read_h5ad( - filename=trial_path, - backed="r", - ) - except OSError as e: - adata = None - print(f"WARNING: for data set {f}: {e}") - if adata is not None: - adata_by_key[adata.uns["id"]] = adata - indices[adata.uns["id"]] = np.arange(0, adata.n_obs) + def __init__(self, in_memory: bool, **kwargs): + super(DistributedStoreH5ad, self).__init__(**kwargs) self._x_as_dask = False - super(DistributedStoreH5ad, self).__init__(adata_by_key=adata_by_key, indices=indices) + self.in_memory = in_memory @property def adata_sliced(self) -> Dict[str, anndata.AnnData]: @@ -418,9 +609,31 @@ def adata_sliced(self) -> Dict[str, anndata.AnnData]: """ return dict([(k, self._adata_by_key[k][v, :]) for k, v in self.indices.items()]) + @property + def indices_global(self) -> dict: + """ + Increasing indices across data sets which can be concatenated into a single index vector with unique entries + for cells. + + E.g.: For two data sets of 10 cells each, the return value would be {A:[0..9], B:[10..19]}. 
+ Note that this operates over pre-selected indices, if this store was subsetted before resulting in only the + second half B to be kept, the return value would be {A:[0..9], B:[10..14]}, where .indices would be + {A:[0..9], B:[15..19]}. + """ + counter = 0 + indices = {} + for k, v in self.indices.items(): + indices[k] = np.arange(counter, counter + len(v)) + counter += len(v) + return indices + @property def X(self): - assert False + if self.in_memory: + assert np.all([isinstance(v.X, scipy.sparse.spmatrix) for v in self.adata_by_key.values()]) + return scipy.sparse.vstack([v.X for v in self.adata_by_key.values()]) + else: + raise NotImplementedError() @property def obs(self) -> Union[pd.DataFrame]: @@ -434,186 +647,149 @@ def obs(self) -> Union[pd.DataFrame]: for k, v in self.indices.items() ], axis=0, join="inner", ignore_index=False, copy=False) - def n_counts(self, idx: Union[np.ndarray, list, None] = None) -> np.ndarray: - """ - Compute sum over features for each observation in index. - - :param idx: Index vector over observations in object. - :return: Array with sum per observations: (number of observations in index,) - """ - return np.concatenate([ - np.asarray(v.X.sum(axis=1)).flatten() - for v in self.adata_by_key_subset(idx=idx).values() - ], axis=0) - - def generator( + def _generator( self, - idx: Union[np.ndarray, None] = None, - batch_size: int = 1, + idx_gen: iter, + var_idx: np.ndarray, obs_keys: List[str] = [], - return_dense: bool = True, - randomized_batch_access: bool = False, + return_dense: bool = False, ) -> iter: """ - Yields an unbiased generator over observations in the contained data sets. + Yields data batches as defined by index sets emitted from index generator. - :param idx: Global idx to query from store. These is an array with indicies corresponding to a contiuous index - along all observations in self.adata_by_key, ordered along a hypothetical concatenation along the keys of - self.adata_by_key. - :param batch_size: Number of observations read from disk in each batched access (generator invocation). + :param idx_gen: Generator that yield two elements in each draw: + - np.ndarray: The cells to emit. + - List[Tuple[int, int]: Batch start and end indices. + :param var_idx: The features to emit. :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available in self.adata_by_key. - :param return_dense: Whether to force return count data .X as dense batches. This allows more efficient feature - indexing if the store is sparse (column indexing on csr matrices is slow). - :param randomized_batch_access: Whether to randomize batches during reading (in generator). Lifts necessity of - using a shuffle buffer on generator, however, batch composition stays unchanged over epochs unless there - is overhangs in retrieval_batch_size in the raw data files, which often happens and results in modest - changes in batch composition. :return: Generator function which yields batch_size at every invocation. - The generator returns a tuple of (.X, .obs) with types: - - - if store format is h5ad: (Union[scipy.sparse.csr_matrix, np.ndarray], pandas.DataFrame) + The generator returns a tuple of (.X, .obs). 
""" - idx, var_idx = self._generator_helper(idx=idx) + adata_sliced = self.adata_sliced + # Speed up access to single object by skipping index overlap operations: + single_object = len(adata_sliced.keys()) == 1 + if not single_object: + idx_dict_global = dict([(k, set(v)) for k, v in self.indices_global.items()]) def generator(): - adatas_sliced_subset = self.adata_by_key_subset(idx=idx) - key_batch_starts_ends = [] # List of tuples of data set key and (start, end) index set of batches. - for k, adata in adatas_sliced_subset.items(): - n_obs = adata.shape[0] - if n_obs > 0: # Skip data objects without matched cells. - # Cells left over after batching to batch size, accounting for overhang: - remainder = n_obs % batch_size - key_batch_starts_ends_k = [ - (k, (int(x * batch_size), int(np.minimum((x * batch_size) + batch_size, n_obs)))) - for x in np.arange(0, n_obs // batch_size + int(remainder > 0)) + for idx, batch_starts_ends in idx_gen: + for s, e in batch_starts_ends: + idx_i = idx[s:e] + # Match adata objects that overlap to batch: + if single_object: + idx_i_dict = dict([(k, np.sort(idx_i)) for k in adata_sliced.keys()]) + else: + idx_i_set = set(idx_i) + idx_i_dict = dict([ + (k, np.sort(list(idx_i_set.intersection(v)))) + for k, v in idx_dict_global.items() + ]) + # Only retain non-empty. + idx_i_dict = dict([(k, v) for k, v in idx_i_dict.items() if len(v) > 0]) + # I) Prepare data matrix. + x = [ + adata_sliced[k].X[v, :] + for k, v in idx_i_dict.items() ] - assert np.sum([v2 - v1 for k, (v1, v2) in key_batch_starts_ends_k]) == n_obs - key_batch_starts_ends.extend(key_batch_starts_ends_k) - batch_range = np.arange(0, len(key_batch_starts_ends)) - if randomized_batch_access: - np.random.shuffle(batch_range) - for i in batch_range: - k, (s, e) = key_batch_starts_ends[i] - x, obs = access_helper(adata=adatas_sliced_subset[k], s=s, e=e, j=var_idx, return_dense=return_dense, - obs_keys=obs_keys) - yield x, obs + # Move from ArrayView to numpy if backed and dense: + x = [ + xx.toarray() if isinstance(xx, anndata._core.views.ArrayView) else xx + for xx in x + ] + # Do dense conversion now so that col-wise indexing is not slow, often, dense conversion + # would be done later anyway. + if return_dense: + x = [np.asarray(xx.todense()) if isinstance(xx, scipy.sparse.spmatrix) else xx for xx in x] + is_dense = True + else: + is_dense = isinstance(x[0], np.ndarray) + # Concatenate blocks in observation dimension: + if len(x) > 1: + if is_dense: + x = np.concatenate(x, axis=0) + else: + x = scipy.sparse.vstack(x) + else: + x = x[0] + if var_idx is not None: + x = x[:, var_idx] + # Prepare .obs. + obs = pd.concat([ + adata_sliced[k].obs[obs_keys].iloc[v, :] + for k, v in idx_i_dict.items() + ], axis=0, join="inner", ignore_index=True, copy=False) + yield x, obs return generator - def adata_by_key_subset(self, idx: Union[np.ndarray, list]) -> Dict[str, anndata.AnnData]: - """ - Subsets adata_by_key as if it was one object, ie behaves the same way as self.adata[idx] without explicitly - concatenating. 
- """ - if idx is not None: - idx = self._validate_idx(idx) - indices_subsetted = {} - counter = 0 - for k, v in self.indices.items(): - n_obs_k = len(v) - indices_global = np.arange(counter, counter + n_obs_k) - indices_subset_k = [x for x, y in zip(v, indices_global) if y in idx] - if len(indices_subset_k) > 0: - indices_subsetted[k] = indices_subset_k - counter += n_obs_k - assert counter == self.n_obs - return dict([(k, self._adata_by_key[k][v, :]) for k, v in indices_subsetted.items()]) - else: - return self.adata_sliced - - def get_subset_idx_global(self, attr_key, values: Union[str, List[str], None] = None, - excluded_values: Union[str, List[str], None] = None) -> np.ndarray: - """ - Get indices of subset list of adata objects based on cell-wise properties treating instance as single array. - The indices are continuous across all data sets as if they were one array. +class DistributedStoreDao(DistributedStoreSingleFeatureSpace): - :param attr_key: Property to subset by. Options: + _dataset_weights: Union[None, Dict[str, float]] + _x: Union[None, dask.array.Array] + _x_by_key: Union[None, dask.array.Array] - - "assay_differentiation" points to self.assay_differentiation_obs_key - - "assay_sc" points to self.assay_sc_obs_key - - "assay_type_differentiation" points to self.assay_type_differentiation_obs_key - - "cell_line" points to self.cell_line - - "cellontology_class" points to self.cellontology_class_obs_key - - "developmental_stage" points to self.developmental_stage_obs_key - - "ethnicity" points to self.ethnicity_obs_key - - "organ" points to self.organ_obs_key - - "organism" points to self.organism_obs_key - - "sample_source" points to self.sample_source_obs_key - - "sex" points to self.sex_obs_key - - "state_exact" points to self.state_exact_obs_key - :param values: Classes to overlap to. - :return Index vector - """ - # Get indices of of cells in target set by file. - idx_by_dataset = self.get_subset_idx(attr_key=attr_key, values=values, excluded_values=excluded_values) - # Translate file-wise indices into global index list across all data sets. - idx = [] - counter = 0 - for k, v in self.indices.items(): - for x in v: - if k in idx_by_dataset.keys() and x in idx_by_dataset[k]: - idx.append(counter) - counter += 1 - return np.asarray(idx) + def __init__(self, x_by_key, **kwargs): + super(DistributedStoreDao, self).__init__(**kwargs) + self._x = None + self._x_as_dask = True + self._x_by_key = x_by_key @property - def indices_global(self) -> dict: - """ - Increasing indices across data sets which can be concatenated into a single index vector with unique entries - for cells. + def indices(self) -> Dict[str, np.ndarray]: + return super(DistributedStoreDao, self).indices - E.g.: For two data sets of 10 cells each, the return value would be {A:[0..9], B:[10..19]}. - Note that this operates over pre-selected indices, if this store was subsetted before resulting in only the - second half B to be kept, the return value would be {A:[0..9], B:[10..14]}, where .indices would be - {A:[0..9], B:[15..19]}. + @indices.setter + def indices(self, x: Dict[str, np.ndarray]): """ - counter = 0 - indices = {} - for k, v in self.indices.items(): - indices[k] = np.arange(counter, counter + len(v)) - counter += len(v) - return indices - + Extends setter in super class by wiping .X cache. 
-class DistributedStoreDao(DistributedStoreBase): + Setter imposes a few constraints on indices: - def __init__(self, cache_path: Union[str, os.PathLike], columns: Union[None, List[str]] = None): + 1) checks that keys are contained ._adata_by_key.keys() + 2) checks that indices are contained in size of values of ._adata_by_key + 3) checks that indces are not duplicated + 4) checks that indices are sorted """ + self._x = None + for k, v in x.items(): + assert k in self._adata_by_key.keys(), f"did not find key {k}" + assert np.max(v) < self._adata_by_key[k].n_obs, f"found index for key {k} that exceeded data set size" + assert len(v) == len(np.unique(v)), f"found duplicated indices for key {k}" + assert np.all(np.diff(v) >= 0), f"indices not sorted for key {k}" + self._indices = x - :param cache_path: Store directory. - :param columns: Which columns to read into the obs copy in the output, see pandas.read_parquet(). + @property + def data_by_key(self): """ - # Collect all data loaders from files in directory: - adata_by_key = {} - indices = {} - for f in os.listdir(cache_path): - adata = None - trial_path = os.path.join(cache_path, f) - if os.path.isdir(trial_path): - # zarr-backed anndata are saved as directories with the elements of the array group as further sub - # directories, e.g. a directory called "X", and a file ".zgroup" which identifies the zarr group. - if [".zgroup" in os.listdir(trial_path)]: - adata = read_dao(trial_path, use_dask=True, columns=columns, obs_separate=False) - print(f"Discovered {f} as zarr group, " - f"sized {round(sys.getsizeof(adata) / np.power(1024, 2), 1)}MB") - if adata is not None: - adata_by_key[adata.uns["id"]] = adata - indices[adata.uns["id"]] = np.arange(0, adata.n_obs) - self._x_as_dask = True - super(DistributedStoreDao, self).__init__(adata_by_key=adata_by_key, indices=indices, obs_by_key=None) + Data matrix for each selected data set in store, sub-setted by selected cells. + """ + # Accesses _x_by_key rather than _adata_by_key as long as the dask arrays are stored there. + return dict([(k, self._x_by_key[k][v, :]) for k, v in self.indices.items()]) @property - def X(self) -> Union[dask.array.Array]: - assert np.all([isinstance(self._adata_by_key[k].X, dask.array.Array) for k in self.indices.keys()]) - return dask.array.vstack([ - self._adata_by_key[k].X[v, :] - for k, v in self.indices.items() - ]) + def X(self) -> dask.array.Array: + """ + One dask array of all cells. + + Requires feature dimension to be shared. + """ + if self._x is None: + if self.data_source == "X": + # TODO avoiding anndata .X here + # assert np.all([isinstance(self._adata_by_key[k].X, dask.array.Array) for k in self.indices.keys()]) + assert np.all([isinstance(self._x_by_key[k], dask.array.Array) for k in self.indices.keys()]) + self._x = dask.optimize(dask.array.vstack([ + self._x_by_key[k][v, :] + for k, v in self.indices.items() + ]))[0] + else: + raise ValueError(f"Did not recognise data_source={self.data_source}.") + return self._x @property - def obs(self) -> Union[pd.DataFrame]: + def obs(self) -> pd.DataFrame: """ Assemble .obs table of subset of selected data. @@ -627,112 +803,49 @@ def obs(self) -> Union[pd.DataFrame]: for k, v in self.indices.items() ], axis=0, join="inner", ignore_index=True, copy=False) - def n_counts(self, idx: Union[np.ndarray, list, None] = None) -> np.ndarray: - """ - Compute sum over features for each observation in index. - - :param idx: Index vector over observations in object. 
- :return: Array with sum per observations: (number of observations in index,) - """ - return np.asarray(self.X.sum(axis=1)).flatten() - - def generator( + def _generator( self, - idx: Union[np.ndarray, None] = None, - batch_size: int = 1, + idx_gen: iter, + var_idx: np.ndarray, obs_keys: List[str] = [], - return_dense: bool = True, - randomized_batch_access: bool = False, - random_access: bool = False, ) -> iter: """ - Yields an unbiased generator over observations in the contained data sets. + Yields data batches as defined by index sets emitted from index generator. - :param idx: Global idx to query from store. These is an array with indicies corresponding to a contiuous index - along all observations in self.adata_by_key, ordered along a hypothetical concatenation along the keys of - self.adata_by_key. - :param batch_size: Number of observations read from disk in each batched access (generator invocation). + :param idx_gen: Generator that yield two elements in each draw: + - np.ndarray: The cells to emit. + - List[Tuple[int, int]: Batch start and end indices. + :param var_idx: The features to emit. :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available in self.adata_by_key. - :param return_dense: Whether to force return count data .X as dense batches. This allows more efficient feature - indexing if the store is sparse (column indexing on csr matrices is slow). - :param randomized_batch_access: Whether to randomize batches during reading (in generator). Lifts necessity of - using a shuffle buffer on generator, however, batch composition stays unchanged over epochs unless there - is overhangs in retrieval_batch_size in the raw data files, which often happens and results in modest - changes in batch composition. - Do not use randomized_batch_access and random_access. - :param random_access: Whether to fully shuffle observations before batched access takes place. May - slow down access compared randomized_batch_access and to no randomization. - Do not use randomized_batch_access and random_access. :return: Generator function which yields batch_size at every invocation. - The generator returns a tuple of (.X, .obs) with types: - - - if store format is h5ad: (Union[scipy.sparse.csr_matrix, np.ndarray], pandas.DataFrame) - """ - idx, var_idx = self._generator_helper(idx=idx) - if randomized_batch_access and random_access: - raise ValueError("Do not use randomized_batch_access and random_access.") + The generator returns a tuple of (.X, .obs). + """ + # Normalise cell indices such that each organism is indexed starting at zero: + # This is required below because each organism is represented as its own dask array. + # TODO this might take a lot of time as the dask array is built. + t0 = time.time() + x = self.X + print(f"init X: {time.time() - t0}") + t0 = time.time() + obs = self.obs[obs_keys] + # Redefine index so that .loc indexing can be used instead of .iloc indexing: + obs.index = np.arange(0, obs.shape[0]) + print(f"init obs: {time.time() - t0}") def generator(): - # Can treat full data set as a single array because dask keeps expression data and obs out of memory. 
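# Sketch of lazy access to the DAO store's stacked count matrix; the cache path is a placeholder
# and the store is assumed to contain human data. Nothing is read from disk until the dask graph
# is evaluated; the row sums below correspond to what the removed n_counts() helper computed.
import sfaira

store = sfaira.data.load_store(cache_path="/path/to/store_dao", store_format="dao").stores["human"]
x = store.X                         # dask.array.Array stacked over all selected data sets
n_counts = x.sum(axis=1).compute()  # per-cell totals, evaluated only at this point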
- x = self.X[idx, :] - obs = self.obs.iloc[idx, :] - # Redefine index so that .loc indexing can be used instead of .iloc indexing: - obs.index = np.arange(0, obs.shape[0]) - n_obs = x.shape[0] - remainder = n_obs % batch_size - assert n_obs == obs.shape[0] - batch_starts_ends = [ - (int(x * batch_size), int(np.minimum((x * batch_size) + batch_size, n_obs))) - for x in np.arange(0, n_obs // batch_size + int(remainder > 0)) - ] - batch_range = np.arange(0, len(batch_starts_ends)) - if randomized_batch_access: - np.random.shuffle(batch_range) - epoch_indices = np.arange(0, n_obs) - if random_access: - np.random.shuffle(epoch_indices) - for i in batch_range: - s, e = batch_starts_ends[i] - # Feature indexing: Run in same operation as observation index so that feature chunking can be - # efficiently used if available. TODO does this make a difference in dask? - if random_access: - if var_idx is not None: - x_i = x[epoch_indices[s:e], var_idx] - else: - x_i = x[epoch_indices[s:e], :] - else: - # Use slicing because observations accessed in batch are ordered in data set: - # Note that epoch_indices[i] == i if not random_access. + # Can all data sets corresponding to one organism as a single array because they share the second dimension + # and dask keeps expression data and obs out of memory. + for idx, batch_starts_ends in idx_gen: + x_temp = x[idx, :] + obs_temp = obs.loc[obs.index[idx], :] # TODO better than iloc? + for s, e in batch_starts_ends: + x_i = x_temp[s:e, :] if var_idx is not None: - x_i = x[s:e, var_idx] - else: - x_i = x[s:e, :] - # Exploit fact that index of obs is just increasing list of integers, so we can use the .loc[] indexing - # instead of .iloc[]: - obs_i = obs[obs_keys].loc[epoch_indices[s:e].tolist(), :] - yield x_i, obs_i + x_i = x_i[:, var_idx] + # Exploit fact that index of obs is just increasing list of integers, so we can use the .loc[] + # indexing instead of .iloc[]: + obs_i = obs_temp.loc[obs_temp.index[s:e], :] + yield x_i, obs_i return generator - - -def load_store(cache_path: Union[str, os.PathLike], store_format: str = "dao", - columns: Union[None, List[str]] = None) -> Union[DistributedStoreH5ad, DistributedStoreDao]: - """ - Instantiates a distributed store class. - - :param cache_path: Store directory. - :param store_format: Format of store {"h5ad", "dao"}. - - - "h5ad": Returns instance of DistributedStoreH5ad. - - "dao": Returns instance of DistributedStoreDoa (distributed access optimized). - :param columns: Which columns to read into the obs copy in the output, see pandas.read_parquet(). - Only relevant if store_format is "dao". - :return: Instances of a distributed store class. 
- """ - if store_format == "h5ad": - return DistributedStoreH5ad(cache_path=cache_path) - elif store_format == "dao": - return DistributedStoreDao(cache_path=cache_path, columns=columns) - else: - raise ValueError(f"Did not recognize store_format {store_format}.") diff --git a/sfaira/data/utils_scripts/streamline_selected.py b/sfaira/data/utils_scripts/streamline_selected.py index 27c2ef0ef..6fc201978 100644 --- a/sfaira/data/utils_scripts/streamline_selected.py +++ b/sfaira/data/utils_scripts/streamline_selected.py @@ -32,7 +32,6 @@ ) ds.streamline_metadata( schema=schema.lower(), - uns_to_obs=False, clean_obs=False, clean_var=True, clean_uns=True, diff --git a/sfaira/data/utils_scripts/test_store.py b/sfaira/data/utils_scripts/test_store.py new file mode 100644 index 000000000..56249fdda --- /dev/null +++ b/sfaira/data/utils_scripts/test_store.py @@ -0,0 +1,284 @@ +import matplotlib.pyplot as plt +import numpy as np +import os +import pandas as pd +import seaborn as sb +import sfaira +import sys +import time +from typing import List + +# Set global variables. +print("sys.argv", sys.argv) + +N_DRAWS = 10 +BATCH_SIZES = [64] + +path_store_h5ad = str(sys.argv[1]) +path_store_dao = str(sys.argv[2]) +path_out = str(sys.argv[3]) + +store_type = [] +kwargs = [] +compression_kwargs = [] +if path_store_h5ad.lower() != "none": + store_type.append("h5ad") + kwargs.append({"dense": False}) + compression_kwargs.append({}) + +store_type.append("dao") +kwargs.append({"dense": True, "chunks": 128}) +compression_kwargs.append({"compressor": "default", "overwrite": True, "order": "C"}) + +time_measurements_initiate = {} +memory_measurements_initiate = {} +time_measurements = { + "load_sequential_from_one_dataset": {}, + "load_sequential_from_many_datasets": {}, + "load_random_from_one_dataset": {}, + "load_random_from_many_datasets": {}, + "load_sequential_from_one_dataset_todense": {}, + "load_sequential_from_many_datasets_todense": {}, + "load_random_from_one_dataset_todense": {}, + "load_random_from_many_datasets_todense": {}, + "load_sequential_from_one_dataset_todense_varsubet": {}, + "load_sequential_from_many_datasets_todense_varsubet": {}, + "load_random_from_one_dataset_todense_varsubet": {}, + "load_random_from_many_datasets_todense_varsubet": {}, +} + + +def time_gen(_store, store_format, kwargs) -> List[float]: + """ + Take samples from generator and measure time taken to generate each sample. 
+ """ + if store_format == "h5ad": + del kwargs["random_access"] + if kwargs["var_subset"]: + gc = sfaira.versions.genomes.genomes.GenomeContainer(assembly="Homo_sapiens.GRCh38.102") + gc.subset(symbols=["VTA1", "MLXIPL", "BAZ1B", "RANBP9", "PPARGC1A", "DDX25", "CRYAB"]) + _store.genome_containers = gc + del kwargs["var_subset"] + _gen = _store.generator(**kwargs)() + _measurements = [] + for _ in range(N_DRAWS): + _t0 = time.time() + _ = next(_gen) + _measurements.append(time.time() - _t0) + return _measurements + + +def get_idx_dataset_start(_store, k_target): + idx = {} + counter = 0 + for k, v in _store.indices.items(): + if k in k_target: + idx[k] = counter + counter += len(v) + return [idx[k] for k in k_target] + + +# Define data objects to be comparable: +store = sfaira.data.load_store(cache_path=path_store_dao, store_format="dao") +store.subset(attr_key="organism", values="human") +store = store.stores["human"] +k_datasets_dao = list(store.indices.keys()) +# Sort by size: +k_datasets_dao = np.asarray(k_datasets_dao)[np.argsort([len(v) for v in store.indices.values()])].tolist() +store = sfaira.data.load_store(cache_path=path_store_h5ad, store_format="h5ad") +store.subset(attr_key="organism", values="human") +store = store.stores["human"] +k_datasets_h5ad = list(store.indices.keys()) +# Only retain intersection of data sets while keeping order. +k_datasets = [x for x in k_datasets_dao if x in k_datasets_h5ad] +n_datasets = len(k_datasets) +print(f"running benchmark on {n_datasets} data sets.") +for store_type_i, kwargs_i, compression_kwargs_i in zip(store_type, kwargs, compression_kwargs): + path_store = path_store_h5ad if store_type_i == "h5ad" else path_store_dao + + # Measure initiate time. + time_measurements_initiate[store_type_i] = [] + memory_measurements_initiate[store_type_i] = [] + for _ in range(3): + t0 = time.time() + store = sfaira.data.load_store(cache_path=path_store, store_format=store_type_i) + # Include initialisation of generator in timing to time overhead generated here. 
+ _ = store.generator() + time_measurements_initiate[store_type_i].append(time.time() - t0) + memory_measurements_initiate[store_type_i].append(np.sum(list(store.adata_memory_footprint.values()))) + + time_measurements["load_sequential_from_one_dataset"][store_type_i] = {} + time_measurements["load_sequential_from_many_datasets"][store_type_i] = {} + time_measurements["load_random_from_one_dataset"][store_type_i] = {} + time_measurements["load_random_from_many_datasets"][store_type_i] = {} + time_measurements["load_sequential_from_one_dataset_todense"][store_type_i] = {} + time_measurements["load_sequential_from_many_datasets_todense"][store_type_i] = {} + time_measurements["load_random_from_one_dataset_todense"][store_type_i] = {} + time_measurements["load_random_from_many_datasets_todense"][store_type_i] = {} + time_measurements["load_sequential_from_one_dataset_todense_varsubet"][store_type_i] = {} + time_measurements["load_sequential_from_many_datasets_todense_varsubet"][store_type_i] = {} + time_measurements["load_random_from_one_dataset_todense_varsubet"][store_type_i] = {} + time_measurements["load_random_from_many_datasets_todense_varsubet"][store_type_i] = {} + store = sfaira.data.load_store(cache_path=path_store, store_format=store_type_i) + store.subset(attr_key="organism", values="human") + store = store.stores["human"] + idx_dataset_start = get_idx_dataset_start(_store=store, k_target=k_datasets) + idx_dataset_end = [i + len(store.indices[x]) for i, x in zip(idx_dataset_start, k_datasets)] + for bs in BATCH_SIZES: + key_bs = "bs" + str(bs) + print(key_bs) + + # Measure load_sequential_from_one_dataset time. + scenario = "load_sequential_from_one_dataset" + print(scenario) + for dense_varsubset in [(False, False), (True, False), (True, True)]: + dense, varsubset = dense_varsubset + suffix = "_todense_varsubet" if dense and varsubset else "_todense" if dense and not varsubset else "" + kwargs = { + "idx": np.concatenate([ + np.arange(idx_dataset_start[0] + bs * i, idx_dataset_start[0] + bs * (i + 1)) + for i in range(N_DRAWS)]), + "batch_size": bs, + "return_dense": dense, + "randomized_batch_access": False, + "random_access": False, + "var_subset": varsubset, + } + time_measurements[scenario + suffix][store_type_i][key_bs] = time_gen( + _store=store, store_format=store_type_i, kwargs=kwargs) + + # Measure load_random_from_one_dataset time. + scenario = "load_random_from_one_dataset" + print(scenario) + for dense_varsubset in [(False, False), (True, False), (True, True)]: + dense, varsubset = dense_varsubset + suffix = "_todense_varsubet" if dense and varsubset else "_todense" if dense and not varsubset else "" + kwargs = { + "idx": np.random.choice( + np.arange(idx_dataset_start[0], np.maximum(idx_dataset_end[0], idx_dataset_start[0] + bs * N_DRAWS)), + size=bs * N_DRAWS, replace=False), + "batch_size": bs, + "return_dense": dense, + "randomized_batch_access": False, + "random_access": False, + "var_subset": varsubset, + } + time_measurements[scenario + suffix][store_type_i][key_bs] = time_gen( + _store=store, store_format=store_type_i, kwargs=kwargs) + + # Measure load_sequential_from_many_datasets time. 
+ scenario = "load_sequential_from_many_datasets" + print(scenario) + for dense_varsubset in [(False, False), (True, False), (True, True)]: + dense, varsubset = dense_varsubset + suffix = "_todense_varsubet" if dense and varsubset else "_todense" if dense and not varsubset else "" + kwargs = { + "idx": np.concatenate([np.arange(s, s + bs) for s in idx_dataset_start]), + "batch_size": bs, + "return_dense": dense, + "randomized_batch_access": False, + "random_access": False, + "var_subset": varsubset, + } + time_measurements[scenario + suffix][store_type_i][key_bs] = time_gen( + _store=store, store_format=store_type_i, kwargs=kwargs) + + # Measure load_random_from_many_datasets time. + scenario = "load_random_from_many_datasets" + print(scenario) + for dense_varsubset in [(False, False), (True, False), (True, True)]: + dense, varsubset = dense_varsubset + suffix = "_todense_varsubet" if dense and varsubset else "_todense" if dense and not varsubset else "" + kwargs = { + "idx": np.concatenate([ + np.random.choice(np.arange(s, np.maximum(e, s + bs)), size=bs, replace=False) + for s, e in zip(idx_dataset_start, idx_dataset_end)]), + "batch_size": bs, + "return_dense": dense, + "randomized_batch_access": False, + "random_access": True, + "var_subset": varsubset, + } + time_measurements[scenario + suffix][store_type_i][key_bs] = time_gen( + _store=store, store_format=store_type_i, kwargs=kwargs) + +ncols = 2 +fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(14, 12)) +for i, x in enumerate([ + [ + "initialisation time", + ], + [ + "initialisation memory", + ], + [ + "load_sequential_from_one_dataset", + "load_sequential_from_one_dataset_todense", + "load_sequential_from_one_dataset_todense_varsubet", + ], + [ + "load_sequential_from_many_datasets", + "load_sequential_from_many_datasets_todense", + "load_sequential_from_many_datasets_todense_varsubet", + ], + [ + "load_random_from_one_dataset", + "load_random_from_one_dataset_todense", + "load_random_from_one_dataset_todense_varsubet", + ], + [ + "load_random_from_many_datasets", + "load_random_from_many_datasets_todense", + "load_random_from_many_datasets_todense_varsubet", + ], +]): + if i == 0 or i == 1: + if i == 0: + measurements_initiate = time_measurements_initiate + ylabel = "log10 time sec" + log = True + else: + measurements_initiate = memory_measurements_initiate + ylabel = "memory MB" + log = False + df_sb = pd.concat([ + pd.DataFrame({ + ylabel: np.log(measurements_initiate[m]) / np.log(10) if log else measurements_initiate[m], + "store": m, + "draw": range(len(measurements_initiate[m])), + }) + for m in measurements_initiate.keys() + ], axis=0) + sb.boxplot( + data=df_sb, + x="store", y=ylabel, + ax=axs[i // ncols, i % ncols] + ) + axs[i // ncols, i % ncols].set_title(x) + elif len(x) > 0: + df_sb = pd.concat([ + pd.concat([ + pd.concat([ + pd.DataFrame({ + "log10 time sec": np.log(time_measurements[m][n][o]) / np.log(10), + "scenario": " ".join(m.split("_")[4:]), + "store": n, + "batch size": o, + "scenario - batch size": " ".join(m.split("_")[4:]) + "_bs" + str(o), + "draw": range(len(time_measurements[m][n][o])), + }) + for o in time_measurements[m][n].keys() + ], axis=0) + for n in time_measurements[m].keys() + ], axis=0) + for m in x + ], axis=0) + # Could collapse draws to mean and put batch size on x. 
+        sb.lineplot(
+            data=df_sb,
+            x="draw", y="log10 time sec", hue="scenario - batch size", style="store",
+            ax=axs[i // ncols, i % ncols]
+        )
+        axs[i // ncols, i % ncols].set_title(x[0])
+plt.tight_layout()
+plt.savefig(os.path.join(path_out, "data_store_benchmark.pdf"))
diff --git a/sfaira/data/utils_scripts/test_streamlined.sh b/sfaira/data/utils_scripts/test_streamlined.sh
new file mode 100644
index 000000000..931454ad8
--- /dev/null
+++ b/sfaira/data/utils_scripts/test_streamlined.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+CODE_PATH="/home/icb/${USER}/git"
+OUT_PATH="/storage/groups/ml01/workspace/david.fischer/sfaira/cellxgene/processed_data/"
+OUT_FN="${OUT_PATH}validation_summary.txt"
+SCHEMA="cellxgene"
+DOIS="10.1016/j.cell.2019.08.008,10.1016/j.celrep.2018.11.086,10.1016/j.cmet.2016.08.020,10.1016/j.neuron.2019.06.011,10.1038/s41422-018-0099-2,10.1038/s41467-018-06318-7,10.1038/s41590-020-0602-z,10.1084/jem.20191130,10.1101/2020.03.13.991455,10.1101/2020.10.12.335331,10.1126/science.aay3224,10.1126/science.aba7721,10.15252/embj.2018100811,no_doi_10x_genomics"
+
+source "/home/${USER}/.bashrc"
+echo "Summary of exports of data sets ${DOIS}" > "${OUT_FN}"
+for doi in ${DOIS//,/ }; do
+    echo "Summary of exports of data set ${doi}" >> "${OUT_FN}"
+    cellxgene schema validate "${OUT_PATH}${doi}/" >> "${OUT_FN}"
+done
+
+CODE_PATH="/home/icb/${USER}/git"
+OUT_PATH="/storage/groups/ml01/workspace/david.fischer/sfaira/cellxgene/processed_data/"
+OUT_FN="${OUT_PATH}validation_summary.txt"
+SCHEMA="cellxgene"
+DOIS="10.1016/j.cell.2017.09.004,10.1016/j.cell.2018.02.001,10.1016/j.cell.2018.08.067,10.1016/j.cell.2019.06.029,10.1016/j.cels.2016.08.011,10.1016/j.cmet.2019.01.021,10.1016/j.devcel.2020.01.033,10.1038/nmeth.4407,10.1038/s41467-019-10861-2,10.1038/s41467-019-12464-3,10.1038/s41467-019-12780-8,10.1038/s41586-018-0698-6,10.1038/s41586-019-1373-2,10.1038/s41586-019-1631-3,10.1038/s41586-019-1652-y,10.1038/s41586-019-1654-9,10.1038/s41586-020-2157-4,10.1038/s41586-020-2922-4,10.1038/s41591-019-0468-5,10.1038/s41593-019-0393-4,10.1038/s41597-019-0351-8,10.1073/pnas.1914143116,10.1101/661728,10.1101/753806,10.1126/science.aat5031,10.1186/s13059-019-1906-x,no_doi_regev"
+
+source "/home/${USER}/.bashrc"
+python ${CODE_PATH}/sfaira/sfaira/data/utils_scripts/streamline_selected.py ${DATA_PATH} ${META_PATH} ${CACHE_PATH} ${OUT_PATH} ${SCHEMA} ${DOIS}
diff --git a/sfaira/data/utils_scripts/write_store.py b/sfaira/data/utils_scripts/write_store.py
index 9b3f319f5..42cd7f90b 100644
--- a/sfaira/data/utils_scripts/write_store.py
+++ b/sfaira/data/utils_scripts/write_store.py
@@ -45,8 +45,7 @@
             match_to_reference={"human": "Homo_sapiens.GRCh38.102", "mouse": "Mus_musculus.GRCm38.102"},
             subset_genes_to_type="protein_coding"
         )
-        ds.streamline_metadata(schema="sfaira", uns_to_obs=True, clean_obs=True, clean_var=True, clean_uns=True,
-                               clean_obs_names=True)
+        ds.streamline_metadata(schema="sfaira", clean_obs=True, clean_var=True, clean_uns=True, clean_obs_names=True)
         ds.write_distributed_store(dir_cache=path_store, store_format=store_type,
                                    compression_kwargs=compression_kwargs, **kwargs)
         ds.clear()
diff --git a/sfaira/estimators/keras.py b/sfaira/estimators/keras.py
index 2ba21fbf5..9452d5fbe 100644
--- a/sfaira/estimators/keras.py
+++ b/sfaira/estimators/keras.py
@@ -13,7 +13,7 @@
 from tqdm import tqdm
 from sfaira.consts import AdataIdsSfaira, OCS, AdataIds
-from sfaira.data import DistributedStoreBase
+from sfaira.data.store.single_store import DistributedStoreSingleFeatureSpace
 from sfaira.models import BasicModelKeras
from sfaira.versions.metadata import CelltypeUniverse, OntologyCl, OntologyObo from sfaira.versions.topologies import TopologyContainer @@ -40,7 +40,7 @@ class EstimatorKeras: """ Estimator base class for keras models. """ - data: Union[anndata.AnnData, DistributedStoreBase] + data: Union[anndata.AnnData, DistributedStoreSingleFeatureSpace] model: Union[BasicModelKeras, None] topology_container: Union[TopologyContainer, None] model_id: Union[str, None] @@ -55,7 +55,7 @@ class EstimatorKeras: def __init__( self, - data: Union[anndata.AnnData, np.ndarray, DistributedStoreBase], + data: Union[anndata.AnnData, np.ndarray, DistributedStoreSingleFeatureSpace], model_dir: Union[str, None], model_class: str, model_id: Union[str, None], @@ -71,7 +71,7 @@ def __init__( self.model_class = model_class self.topology_container = model_topology # Prepare store with genome container sub-setting: - if isinstance(self.data, DistributedStoreBase): + if isinstance(self.data, DistributedStoreSingleFeatureSpace): self.data.genome_container = self.topology_container.gc self.history = None @@ -159,7 +159,7 @@ def load_weights_from_cache(self): def init_model(self, clear_weight_cache=True, override_hyperpar=None): """ - instantiate the model + Instantiate the model. :return: """ if clear_weight_cache: @@ -236,23 +236,20 @@ def _prepare_data_matrix(self, idx: Union[np.ndarray, None]) -> scipy.sparse.csr x = x[idx, :] # If the feature space is already mapped to the right reference, return the data matrix immediately - if self._adata_ids.mapped_features in self.data.uns_keys() and \ - self.data.uns[self._adata_ids.mapped_features] == self.topology_container.gc.assembly: - print(f"found {x.shape[0]} observations") - return x - - # Compute indices of genes to keep - data_ids = self.data.var[self._adata_ids.gene_id_ensembl].values.tolist() - target_ids = self.topology_container.gc.ensembl - idx_map = np.array([data_ids.index(z) for z in target_ids]) - # Assert that each ID from target IDs appears exactly once in data IDs: - assert np.all([z in data_ids for z in target_ids]), "not all target feature IDs found in data" - assert np.all([np.sum(z == np.array(data_ids)) <= 1. for z in target_ids]), \ - "duplicated target feature IDs exist in data" - # Map feature space. - x = x[:, idx_map] - print(f"found {len(idx_map)} intersecting features between {x.shape[1]} features in input data set and" - f" {self.topology_container.n_var} features in reference genome") + if self.data.n_vars != self.topology_container.n_var or \ + not np.all(self.data.var[self._adata_ids.gene_id_ensembl] == self.topology_container.gc.ensembl): + # Compute indices of genes to keep + data_ids = self.data.var[self._adata_ids.gene_id_ensembl].values.tolist() + target_ids = self.topology_container.gc.ensembl + idx_map = np.array([data_ids.index(z) for z in target_ids]) + # Assert that each ID from target IDs appears exactly once in data IDs: + assert np.all([z in data_ids for z in target_ids]), "not all target feature IDs found in data" + assert np.all([np.sum(z == np.array(data_ids)) <= 1. for z in target_ids]), \ + "duplicated target feature IDs exist in data" + # Map feature space. 
+ x = x[:, idx_map] + print(f"found {len(idx_map)} intersecting features between {x.shape[1]} features in input data set and" + f" {self.topology_container.n_var} features in reference genome") print(f"found {x.shape[0]} observations") return x @@ -293,18 +290,25 @@ def split_train_val_test(self, val_split: float, test_split: Union[float, dict]) if k not in self.data.obs.columns: raise ValueError(f"Did not find column {k} used to define test set in self.data.") in_test = np.logical_and(in_test, np.array([x in v for x in self.data.obs[k].values])) - elif isinstance(self.data, DistributedStoreBase): - idx = self.data.get_subset_idx_global(attr_key=k, values=v) + elif isinstance(self.data, DistributedStoreSingleFeatureSpace): + idx = self.data.get_subset_idx(attr_key=k, values=v, excluded_values=None) + # Build continuous vector across all sliced data sets and establish which observations are kept + # in subset. in_test_k = np.ones((self.data.n_obs,), dtype=int) == 0 - in_test_k[idx] = True + counter = 0 + for kk, vv in self.data.indices.items(): + if kk in idx.keys() and len(idx[kk]) > 0: + in_test_k[np.where([x in idx[kk] for x in vv])[0] + counter] = True + counter += len(vv) in_test = np.logical_and(in_test, in_test_k) else: assert False self.idx_test = np.sort(np.where(in_test)[0]) - print(f"Found {len(self.idx_test)} out of {self.data.n_obs} cells that correspond to held out data set") - print(self.idx_test) else: raise ValueError("type of test_split %s not recognized" % type(test_split)) + print(f"Found {len(self.idx_test)} out of {self.data.n_obs} cells that correspond to test data set") + assert len(self.idx_test) < self.data.n_obs, "test set covers full data set, apply a more restrictive test " \ + "data definiton" idx_train_eval = np.array([x for x in all_idx if x not in self.idx_test]) np.random.seed(1) self.idx_eval = np.sort(np.random.choice( @@ -491,7 +495,7 @@ def train( @property def using_store(self) -> bool: - return isinstance(self.data, DistributedStoreBase) + return isinstance(self.data, DistributedStoreSingleFeatureSpace) @property def obs_train(self): @@ -513,7 +517,7 @@ class EstimatorKerasEmbedding(EstimatorKeras): def __init__( self, - data: Union[anndata.AnnData, np.ndarray, DistributedStoreBase], + data: Union[anndata.AnnData, np.ndarray, DistributedStoreSingleFeatureSpace], model_dir: Union[str, None], model_id: Union[str, None], model_topology: TopologyContainer, @@ -601,7 +605,7 @@ def _get_base_generator( # Prepare data reading according to whether anndata is backed or not: if self.using_store: - generator_raw = self.data.generator( + generator_raw, _ = self.data.generator( idx=idx, batch_size=batch_size, obs_keys=[], @@ -999,7 +1003,7 @@ class EstimatorKerasCelltype(EstimatorKeras): def __init__( self, - data: Union[anndata.AnnData, DistributedStoreBase], + data: Union[anndata.AnnData, DistributedStoreSingleFeatureSpace], model_dir: Union[str, None], model_id: Union[str, None], model_topology: TopologyContainer, @@ -1022,7 +1026,7 @@ def __init__( ) if remove_unlabeled_cells: # Remove cells without type label from store: - if isinstance(self.data, DistributedStoreBase): + if isinstance(self.data, DistributedStoreSingleFeatureSpace): self.data.subset(attr_key="cellontology_class", excluded_values=[ self._adata_ids.unknown_celltype_identifier, self._adata_ids.not_a_cell_celltype_identifier, @@ -1161,6 +1165,7 @@ def _get_base_generator( weighted: bool, batch_size: int, randomized_batch_access: bool, + **kwargs, ): """ Yield a basic generator based on which a tf 
dataset can be built. @@ -1188,7 +1193,7 @@ def _get_base_generator( if self.using_store: if weighted: raise ValueError("using weights with store is not supported yet") - generator_raw = self.data.generator( + generator_raw, _ = self.data.generator( idx=idx, batch_size=batch_size, obs_keys=[self._adata_ids.cellontology_id], diff --git a/sfaira/estimators/metrics.py b/sfaira/estimators/metrics.py index f075b9da5..8d26fe29c 100644 --- a/sfaira/estimators/metrics.py +++ b/sfaira/estimators/metrics.py @@ -12,6 +12,15 @@ def custom_mse(y_true, y_pred, sample_weight=None): return se_red +def custom_mean_squared_logp1_error(y_true, y_pred, sample_weight=None): + y_pred = tf.split(y_pred, num_or_size_splits=2, axis=1)[0] + y_true = tf.math.log(y_true + 1.) + y_pred = tf.math.log(y_pred + 1.) + se = tf.square(tf.subtract(y_true, y_pred)) + se_red = tf.reduce_mean(se) + return se_red + + def custom_negll_nb(y_true, y_pred, sample_weight=None): x = y_true loc, scale = tf.split(y_pred, num_or_size_splits=2, axis=1) diff --git a/sfaira/train/summaries.py b/sfaira/train/summaries.py index 933ba3924..8d567c614 100644 --- a/sfaira/train/summaries.py +++ b/sfaira/train/summaries.py @@ -1370,13 +1370,15 @@ def plot_best( def get_gradients_by_celltype( self, - organ: Union[str, None], + model_organ: str, + data_organ: str, organism: Union[str, None], genome: Union[str, None, dict], model_type: Union[str, List[str]], metric_select: str, data_source: str, datapath, + gene_type: str = "protein_coding", configpath: Union[None, str] = None, store_format: Union[None, str] = None, test_data=True, @@ -1387,7 +1389,8 @@ def get_gradients_by_celltype( """ Compute gradients across latent units with respect to input features for each cell type. - :param organ: + :param model_organ: + :param data_organ: :param organism: :param model_type: :param metric_select: @@ -1403,7 +1406,7 @@ def get_gradients_by_celltype( metric_select=metric_select, partition_select=partition_select, subset={ - "organ": organ, + "organ": model_organ, "model_type": model_type, } ) @@ -1439,10 +1442,10 @@ def get_gradients_by_celltype( u = Universe(data_path=datapath) if organism is not None: u.subset("organism", organism) - if organ is not None: - u.subset("organ", organ) + if data_organ is not None: + u.subset("organ", data_organ) u.load(allow_caching=False) - u.streamline_features(match_to_reference=genome) + u.streamline_features(match_to_reference=genome, subset_genes_to_type=gene_type) u.streamline_metadata() adata = u.adata else: @@ -1482,7 +1485,8 @@ def get_gradients_by_celltype( def plot_gradient_distr( self, - organ: str, + model_organ: str, + data_organ: str, model_type: Union[str, List[str]], metric_select: str, datapath: str, @@ -1492,6 +1496,7 @@ def plot_gradient_distr( configpath: Union[None, str] = None, store_format: Union[None, str] = None, test_data=True, + gene_type: str = "protein_coding", partition_select: str = "val", normalize=True, remove_inactive=True, @@ -1521,11 +1526,13 @@ def plot_gradient_distr( celltypes = {} for modelt in model_type: avg_grads[modelt], celltypes[modelt] = self.get_gradients_by_celltype( - organ=organ, + model_organ=model_organ, + data_organ=data_organ, organism=organism, model_type=modelt, metric_select=metric_select, genome=genome, + gene_type=gene_type, data_source=data_source, datapath=datapath, configpath=configpath, @@ -1584,7 +1591,8 @@ def plot_gradient_distr( def plot_gradient_cor( self, - organ: str, + model_organ: str, + data_organ: str, model_type: Union[str, List[str]], metric_select: 
str, datapath: str, @@ -1594,6 +1602,7 @@ def plot_gradient_cor( configpath: Union[None, str] = None, store_format: Union[None, str] = None, test_data=True, + gene_type: str = "protein_coding", partition_select: str = "val", height_fig=7, width_fig=7, @@ -1607,7 +1616,8 @@ def plot_gradient_cor( """ Plot correlation heatmap of gradient vectors accumulated on input features between cell types or models. - :param organ: + :param model_organ: + :param data_organ: :param model_type: :param metric_select: :param datapath: @@ -1641,11 +1651,13 @@ def plot_gradient_cor( celltypes = {} for modelt in model_type: avg_grads[modelt], celltypes[modelt] = self.get_gradients_by_celltype( - organ=organ, + model_organ=model_organ, + data_organ=data_organ, organism=organism, model_type=modelt, metric_select=metric_select, genome=genome, + gene_type=gene_type, data_source=data_source, datapath=datapath, configpath=configpath, diff --git a/sfaira/train/train_model.py b/sfaira/train/train_model.py index dc72b8ca9..4ed5aede9 100644 --- a/sfaira/train/train_model.py +++ b/sfaira/train/train_model.py @@ -6,19 +6,19 @@ from typing import Union from sfaira.consts import AdataIdsSfaira -from sfaira.data import DistributedStoreBase, Universe +from sfaira.data import DistributedStoreSingleFeatureSpace, Universe from sfaira.estimators import EstimatorKeras, EstimatorKerasCelltype, EstimatorKerasEmbedding from sfaira.ui import ModelZoo class TrainModel: - data: Union[anndata.AnnData, DistributedStoreBase] + data: Union[anndata.AnnData, DistributedStoreSingleFeatureSpace] estimator: Union[EstimatorKeras, None] def __init__( self, - data: Union[str, anndata.AnnData, Universe, DistributedStoreBase], + data: Union[str, anndata.AnnData, Universe, DistributedStoreSingleFeatureSpace], ): # Check if handling backed anndata or base path to directory of raw files: if isinstance(data, str) and data.split(".")[-1] == "h5ad": @@ -30,7 +30,7 @@ def __init__( self.data = data elif isinstance(data, Universe): self.data = data.adata - elif isinstance(data, DistributedStoreBase): + elif isinstance(data, DistributedStoreSingleFeatureSpace): self.data = data else: raise ValueError(f"did not recongize data of type {type(data)}") @@ -42,7 +42,7 @@ def load_into_memory(self): Loads backed objects from DistributedStoreBase into single adata object in memory in .data slot. 
:return: """ - if isinstance(self.data, DistributedStoreBase): + if isinstance(self.data, DistributedStoreSingleFeatureSpace): adata = None for k, v in self.data.indices.items(): x = self.data.adata_by_key[k][v, :].to_memory() @@ -93,14 +93,9 @@ def save( self._save_specific(fn=fn) def n_counts(self, idx): - if isinstance(self.estimator.data, anndata.AnnData): - return np.asarray( - self.estimator.data.X[np.sort(idx), :].sum(axis=1)[np.argsort(idx)] - ).flatten() - elif isinstance(self.estimator.data, DistributedStoreBase): - return self.estimator.data.n_counts(idx=idx) - else: - assert False + return np.asarray( + self.estimator.data.X[np.sort(idx), :].sum(axis=1)[np.argsort(idx)] + ).flatten() class TrainModelEmbedding(TrainModel): @@ -110,7 +105,7 @@ class TrainModelEmbedding(TrainModel): def __init__( self, model_path: str, - data: Union[str, anndata.AnnData, Universe, DistributedStoreBase], + data: Union[str, anndata.AnnData, Universe, DistributedStoreSingleFeatureSpace], ): super(TrainModelEmbedding, self).__init__(data=data) self.estimator = None @@ -173,7 +168,7 @@ class TrainModelCelltype(TrainModel): def __init__( self, model_path: str, - data: Union[str, anndata.AnnData, Universe, DistributedStoreBase], + data: Union[str, anndata.AnnData, Universe, DistributedStoreSingleFeatureSpace], fn_target_universe: str, ): super(TrainModelCelltype, self).__init__(data=data) diff --git a/sfaira/ui/model_zoo.py b/sfaira/ui/model_zoo.py index 8b0eb0faf..a6c757db2 100644 --- a/sfaira/ui/model_zoo.py +++ b/sfaira/ui/model_zoo.py @@ -1,4 +1,3 @@ -import abc import numpy as np import pandas as pd from typing import List, Union @@ -8,16 +7,20 @@ from sfaira.versions.topologies import TopologyContainer, TOPOLOGIES -class ModelZoo(abc.ABC): +class ModelZoo: + """ - Model zoo base class. + Model zoo class. """ - topology_container: TopologyContainer - zoo: Union[dict, None] + _model_id: Union[str, None] - celltypes: Union[CelltypeUniverse, None] available_model_ids: Union[list, None] + celltypes: Union[CelltypeUniverse, None] topology_container: Union[None, TopologyContainer] + zoo: Union[dict, None] + + TOPOLOGIES = TOPOLOGIES + TOPOLOGY_CONTAINER_CLASS = TopologyContainer def __init__( self, @@ -149,8 +152,8 @@ def model_id(self, x: str): f"{x} not found in available_model_ids, please check available models using ModelZoo.available_model_ids" assert len(x.split('_')) == 3, f'model_id {x} is invalid' self._model_id = x - self.topology_container = TopologyContainer( - topology=TOPOLOGIES[self.model_organism][self.model_class][self.model_type][self.model_topology], + self.topology_container = self.TOPOLOGY_CONTAINER_CLASS( + topology=self.TOPOLOGIES[self.model_organism][self.model_class][self.model_type][self.model_topology], topology_id=self.model_version ) diff --git a/sfaira/ui/user_interface.py b/sfaira/ui/user_interface.py index 96a831c7b..6bde042a7 100644 --- a/sfaira/ui/user_interface.py +++ b/sfaira/ui/user_interface.py @@ -8,7 +8,7 @@ import warnings import time -from sfaira.consts import AdataIdsSfaira, AdataIds +from sfaira.consts import AdataIdsSfaira, AdataIds, OCS from sfaira.data import DatasetInteractive from sfaira.estimators import EstimatorKerasEmbedding, EstimatorKerasCelltype from sfaira.ui.model_zoo import ModelZoo @@ -354,6 +354,11 @@ def load_data( :param obs_key_celltypes: .obs column name which contains cell type labels. :param class_maps: Cell type class maps. 
""" + if self.zoo_embedding.model_organism is not None and self.zoo_celltype.model_organism is not None: + assert self.zoo_embedding.model_organism == self.zoo_celltype.model_organism, \ + "Model ids set for embedding and celltype model need to correspond to the same organism" + assert self.zoo_embedding.model_organ == self.zoo_celltype.model_organ, \ + "Model ids set for embedding and celltype model need to correspond to the same organ" if self.zoo_embedding.model_organism is not None: organism = self.zoo_embedding.model_organism organ = self.zoo_embedding.model_organ @@ -366,6 +371,12 @@ def load_data( if gene_ens_col is None and gene_symbol_col is None: raise ValueError("Please provide either the gene_ens_col or the gene_symbol_col argument.") + # handle organ names with stripped spaces + if organ not in OCS.organ.node_names: + organ_dict = {i.replace(" ", ""): i for i in OCS.organ.node_names} + assert organ in organ_dict, f"Organ {organ} is not a valid nodename in the UBERON organ ontology" + organ = {i.replace(" ", ""): i for i in OCS.organ.node_names}[organ] + self.data = DatasetInteractive( data=data, organism=organism, diff --git a/sfaira/unit_tests/__init__.py b/sfaira/unit_tests/__init__.py index e69de29bb..985d04c12 100644 --- a/sfaira/unit_tests/__init__.py +++ b/sfaira/unit_tests/__init__.py @@ -0,0 +1 @@ +from .directories import DIR_TEMP diff --git a/sfaira/unit_tests/data/test_clean_celltype_maps.py b/sfaira/unit_tests/data/test_clean_celltype_maps.py deleted file mode 100644 index 4ce259e08..000000000 --- a/sfaira/unit_tests/data/test_clean_celltype_maps.py +++ /dev/null @@ -1,13 +0,0 @@ -from sfaira.data.dataloaders.loaders import DatasetSuperGroupLoaders - - -def test_map_celltype_to_ontology(): - # Paths do not matter here as data sets are not loaded for these operations. 
- dsgl = DatasetSuperGroupLoaders( - data_path="~", - meta_path="~", - cache_path="~" - ) - for x in dsgl.dataset_groups: - print(x.ids) - x.clean_ontology_class_map() diff --git a/sfaira/unit_tests/data/__init__.py b/sfaira/unit_tests/data_for_tests/__init__.py similarity index 100% rename from sfaira/unit_tests/data/__init__.py rename to sfaira/unit_tests/data_for_tests/__init__.py diff --git a/sfaira/unit_tests/data_for_tests/databases/__init__.py b/sfaira/unit_tests/data_for_tests/databases/__init__.py new file mode 100644 index 000000000..6717d175e --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/databases/__init__.py @@ -0,0 +1 @@ +from .utils import prepare_dsg_database diff --git a/sfaira/unit_tests/data_for_tests/databases/consts.py b/sfaira/unit_tests/data_for_tests/databases/consts.py new file mode 100644 index 000000000..9786ea673 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/databases/consts.py @@ -0,0 +1,2 @@ +CELLXGENE_COLLECTION_ID = "558385a4-b7b7-4eca-af0c-9e54d010e8dc" +CELLXGENE_DATASET_ID = "774c18c5-efa1-4dc5-9e5e-2c824bab2e34" diff --git a/sfaira/unit_tests/data_for_tests/databases/utils.py b/sfaira/unit_tests/data_for_tests/databases/utils.py new file mode 100644 index 000000000..b8cb00a53 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/databases/utils.py @@ -0,0 +1,28 @@ +import os +import pathlib + +from sfaira.data import DatasetSuperGroup +from sfaira.data.dataloaders.databases import DatasetSuperGroupDatabases +from sfaira.unit_tests.data_for_tests.databases.consts import CELLXGENE_COLLECTION_ID + +from sfaira.unit_tests.directories import DIR_DATA_DATABASES_CACHE + + +def prepare_dsg_database(database: str, download: bool = True) -> DatasetSuperGroup: + """ + Prepares data set super group of data base returns instance. + + :param database: Database to make available. + :param download: Whether to make sure that raw files are downloaded. + """ + if not os.path.exists(DIR_DATA_DATABASES_CACHE): + pathlib.Path(DIR_DATA_DATABASES_CACHE).mkdir(parents=True, exist_ok=True) + if database == "cellxgene": + dsg = DatasetSuperGroupDatabases(data_path=DIR_DATA_DATABASES_CACHE) + # Only retain pre-defined target collections to avoid bulk downloads during unit tests. 
+ dsg.subset(key="collection_id", values=CELLXGENE_COLLECTION_ID) + else: + assert False, database + if download: + dsg.download() + return dsg diff --git a/sfaira/unit_tests/data_for_tests/loaders/__init__.py b/sfaira/unit_tests/data_for_tests/loaders/__init__.py new file mode 100644 index 000000000..f2096a71b --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/__init__.py @@ -0,0 +1,3 @@ +from .consts import ASSEMBLY_HUMAN, ASSEMBLY_MOUSE +from .loaders import DatasetSuperGroupMock +from .utils import prepare_dsg, prepare_store diff --git a/sfaira/unit_tests/data_for_tests/loaders/consts.py b/sfaira/unit_tests/data_for_tests/loaders/consts.py new file mode 100644 index 000000000..549523759 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/consts.py @@ -0,0 +1,5 @@ +ASSEMBLY_HUMAN = "Homo_sapiens.GRCh38.102" +ASSEMBLY_MOUSE = "Mus_musculus.GRCm38.102" + +CELLTYPES = ["adventitial cell", "endothelial cell", "acinar cell", "pancreatic PP cell", "type B pancreatic cell"] +CL_VERSION = "v2021-02-01_cl" diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/__init__.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/__init__.py new file mode 100644 index 000000000..2e97bdc09 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/__init__.py @@ -0,0 +1 @@ +from .super_group import DatasetSuperGroupMock diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/__init__.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/__init__.py new file mode 100644 index 000000000..b1d5b2c2b --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/__init__.py @@ -0,0 +1 @@ +FILE_PATH = __file__ diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.py new file mode 100644 index 000000000..92425a428 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.py @@ -0,0 +1,12 @@ +import anndata + +from sfaira.unit_tests.data_for_tests.loaders.consts import ASSEMBLY_HUMAN, CELLTYPES +from sfaira.unit_tests.data_for_tests.loaders.utils import _create_adata + + +def load(data_dir, sample_fn, **kwargs) -> anndata.AnnData: + ncells = 100 + ngenes = 50 + adata = _create_adata(celltypes=CELLTYPES[:2], ncells=ncells, ngenes=ngenes, + assembly=ASSEMBLY_HUMAN) + return adata diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.tsv b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.tsv new file mode 100644 index 000000000..4b12455d6 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.tsv @@ -0,0 +1,3 @@ +source target target_id +adventitial cell adventitial cell CL:0002503 +endothelial cell endothelial cell CL:0000115 diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.yaml b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.yaml new file mode 100644 index 000000000..3c64351e8 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.yaml @@ -0,0 +1,52 @@ +dataset_structure: + dataset_index: 1 + sample_fns: +dataset_wise: + 
author: + - "mock1" + default_embedding: + doi_journal: "no_doi_mock1" + doi_preprint: + download_url_data: "" + download_url_meta: "" + normalization: "raw" + primary_data: + year: 2021 +dataset_or_observation_wise: + assay_sc: "10x technology" + assay_sc_obs_key: + assay_differentiation: + assay_differentiation_obs_key: + assay_type_differentiation: + assay_type_differentiation_obs_key: + bio_sample: + bio_sample_obs_key: + cell_line: + cell_line_obs_key: + development_stage: + development_stage_obs_key: + disease: "healthy" + disease_obs_key: + ethnicity: + ethnicity_obs_key: + individual: + individual_obs_key: + organ: "lung" + organ_obs_key: + organism: "human" + organism_obs_key: + sample_source: "primary_tissue" + sample_source_obs_key: + sex: + sex_obs_key: + state_exact: + state_exact_obs_key: + tech_sample: + tech_sample_obs_key: +observation_wise: + cell_types_original_obs_key: "free_annotation" +feature_wise: + gene_id_ensembl_var_key: "index" + gene_id_symbols_var_key: +meta: + version: "1.0" diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/__init__.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/__init__.py new file mode 100644 index 000000000..b1d5b2c2b --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/__init__.py @@ -0,0 +1 @@ +FILE_PATH = __file__ diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.py new file mode 100644 index 000000000..e8b2921c7 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.py @@ -0,0 +1,12 @@ +import anndata + +from sfaira.unit_tests.data_for_tests.loaders.consts import ASSEMBLY_MOUSE, CELLTYPES +from sfaira.unit_tests.data_for_tests.loaders.utils import _create_adata + + +def load(data_dir, sample_fn, **kwargs) -> anndata.AnnData: + ncells = 100 + ngenes = 70 + adata = _create_adata(celltypes=CELLTYPES[3:6], ncells=ncells, ngenes=ngenes, + assembly=ASSEMBLY_MOUSE) + return adata diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.tsv b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.tsv new file mode 100644 index 000000000..8cac1011a --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.tsv @@ -0,0 +1,4 @@ +source target target_id +acinar cell pancreatic acinar cell CL:0002064 +alpha cell pancreatic A cell CL:0000171 +beta cell type B pancreatic cell CL:0000169 diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.yaml b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.yaml new file mode 100644 index 000000000..436de0756 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.yaml @@ -0,0 +1,52 @@ +dataset_structure: + dataset_index: 1 + sample_fns: +dataset_wise: + author: + - "mock3" + default_embedding: + doi_journal: "no_doi_mock2" + doi_preprint: + download_url_data: "" + download_url_meta: "" + normalization: "raw" + primary_data: + year: 2021 +dataset_or_observation_wise: + assay_sc: "10x technology" + assay_sc_obs_key: + 
assay_differentiation: + assay_differentiation_obs_key: + assay_type_differentiation: + assay_type_differentiation_obs_key: + bio_sample: + bio_sample_obs_key: + cell_line: + cell_line_obs_key: + development_stage: + development_stage_obs_key: + disease: "healthy" + disease_obs_key: + ethnicity: + ethnicity_obs_key: + individual: + individual_obs_key: + organ: "pancreas" + organ_obs_key: + organism: "mouse" + organism_obs_key: + sample_source: "primary_tissue" + sample_source_obs_key: + sex: + sex_obs_key: + state_exact: + state_exact_obs_key: + tech_sample: + tech_sample_obs_key: +observation_wise: + cell_types_original_obs_key: "free_annotation" +feature_wise: + gene_id_ensembl_var_key: "index" + gene_id_symbols_var_key: +meta: + version: "1.0" diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/__init__.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/__init__.py new file mode 100644 index 000000000..b1d5b2c2b --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/__init__.py @@ -0,0 +1 @@ +FILE_PATH = __file__ diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.py new file mode 100644 index 000000000..2038c1bf2 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.py @@ -0,0 +1,12 @@ +import anndata + +from sfaira.unit_tests.data_for_tests.loaders.consts import ASSEMBLY_HUMAN, CELLTYPES +from sfaira.unit_tests.data_for_tests.loaders.utils import _create_adata + + +def load(data_dir, sample_fn, **kwargs) -> anndata.AnnData: + ncells = 100 + ngenes = 60 + adata = _create_adata(celltypes=CELLTYPES[:2], ncells=ncells, ngenes=ngenes, + assembly=ASSEMBLY_HUMAN) + return adata diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.tsv b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.tsv new file mode 100644 index 000000000..4b12455d6 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.tsv @@ -0,0 +1,3 @@ +source target target_id +adventitial cell adventitial cell CL:0002503 +endothelial cell endothelial cell CL:0000115 diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.yaml b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.yaml new file mode 100644 index 000000000..e2f876bff --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.yaml @@ -0,0 +1,52 @@ +dataset_structure: + dataset_index: 1 + sample_fns: +dataset_wise: + author: + - "mock2" + default_embedding: + doi_journal: "no_doi_mock3" + doi_preprint: + download_url_data: "" + download_url_meta: "" + normalization: "raw" + primary_data: + year: 2021 +dataset_or_observation_wise: + assay_sc: "10x technology" + assay_sc_obs_key: + assay_differentiation: + assay_differentiation_obs_key: + assay_type_differentiation: + assay_type_differentiation_obs_key: + bio_sample: + bio_sample_obs_key: + cell_line: + cell_line_obs_key: + development_stage: + development_stage_obs_key: + disease: "healthy" + disease_obs_key: + ethnicity: + ethnicity_obs_key: + individual: + 
+  individual_obs_key:
+  organ: "lung"
+  organ_obs_key:
+  organism: "human"
+  organism_obs_key:
+  sample_source: "primary_tissue"
+  sample_source_obs_key:
+  sex:
+  sex_obs_key:
+  state_exact:
+  state_exact_obs_key:
+  tech_sample:
+  tech_sample_obs_key:
+observation_wise:
+  cell_types_original_obs_key: "free_annotation"
+feature_wise:
+  gene_id_ensembl_var_key: "index"
+  gene_id_symbols_var_key:
+meta:
+  version: "1.0"
diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/super_group.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/super_group.py
new file mode 100644
index 000000000..46b508203
--- /dev/null
+++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/super_group.py
@@ -0,0 +1,60 @@
+import pydoc
+import os
+from typing import List
+from warnings import warn
+from sfaira.data import DatasetBase, DatasetGroup, DatasetSuperGroup, DatasetGroupDirectoryOriented
+
+from sfaira.unit_tests.directories import DIR_DATA_LOADERS_CACHE
+
+
+class DatasetSuperGroupMock(DatasetSuperGroup):
+    """
+    This is a DatasetSuperGroup which wraps the mock data loaders in the same directory.
+
+    This class is designed to facilitate testing of code that requires data loaders without requiring raw data
+    downloads as all mock data loaders operate on data that is simulated in the `load()` functions.
+    A cache directory is established under ../cache.
+
+    This class is a reduced and merged version of the sfaira loader super group class and the sfaira loader adapted
+    DatasetGroupDirectoryOriented.
+    """
+
+    dataset_groups: List[DatasetGroupDirectoryOriented]
+
+    def __init__(self):
+        # Directory choice hyper-parameters:
+        dir_prefix = "d"
+        dir_exclude = []
+        # Collect all data loaders from files in directory:
+        dataset_groups = []
+        cwd = os.path.dirname(__file__)
+        for d in os.listdir(cwd):
+            if os.path.isdir(os.path.join(cwd, d)):  # Iterate over mock studies (directories).
+                if d[:len(dir_prefix)] == dir_prefix and d not in dir_exclude:  # Narrow down to data set directories.
+                    path_base = f"sfaira.unit_tests.data_for_tests.loaders.loaders.{d}"
+                    path_dsg = pydoc.locate(f"{path_base}.FILE_PATH")
+                    path_module = os.path.join(cwd, d)
+                    for f in os.listdir(os.path.join(cwd, d)):  # Iterate over loaders in mock study (file).
+ datasets = [] + if f.split(".")[-1] == "py" and f not in ["__init__.py"]: + file_module = ".".join(f.split(".")[:-1]) + if path_dsg is not None: + load_func = pydoc.locate(f"{path_base}.{file_module}.load") + fn_yaml = os.path.join(path_module, file_module + ".yaml") + x = DatasetBase( + data_path=DIR_DATA_LOADERS_CACHE, + meta_path=DIR_DATA_LOADERS_CACHE, + cache_path=DIR_DATA_LOADERS_CACHE, + load_func=load_func, + dict_load_func_annotation=None, + sample_fn=None, + sample_fns=None, + yaml_path=fn_yaml, + ) + x.load_ontology_class_map(fn=os.path.join(path_module, file_module + ".tsv")) + datasets.append(x) + else: + warn(f"DatasetGroupDirectoryOriented was None for {f}") + dataset_groups.append(DatasetGroup(datasets=dict([(x.id, x) for x in datasets]), + collection_id=d)) + super().__init__(dataset_groups=dataset_groups) diff --git a/sfaira/unit_tests/data_for_tests/loaders/utils.py b/sfaira/unit_tests/data_for_tests/loaders/utils.py new file mode 100644 index 000000000..a94ddf5d4 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/utils.py @@ -0,0 +1,84 @@ +import anndata +import scipy.sparse +import numpy as np +import os +import pandas as pd +import pathlib +from sfaira.versions.genomes import GenomeContainer + +from sfaira.unit_tests.directories import DIR_DATA_LOADERS_CACHE, DIR_DATA_LOADERS_STORE_DAO, \ + DIR_DATA_LOADERS_STORE_H5AD +from .consts import ASSEMBLY_HUMAN, ASSEMBLY_MOUSE +from .loaders import DatasetSuperGroupMock + +MATCH_TO_REFERENCE = {"human": ASSEMBLY_HUMAN, "mouse": ASSEMBLY_MOUSE} + + +def _create_adata(celltypes, ncells, ngenes, assembly) -> anndata.AnnData: + """ + Usesd by mock data loaders. + """ + gc = GenomeContainer(assembly=assembly) + gc.subset(biotype="protein_coding") + genes = gc.ensembl[:ngenes] + x = scipy.sparse.csc_matrix(np.random.randint(low=0, high=100, size=(ncells, ngenes))) + var = pd.DataFrame(index=genes) + obs = pd.DataFrame({ + "free_annotation": [celltypes[i] for i in np.random.choice(a=[0, 1], size=ncells, replace=True)] + }, index=["cell_" + str(i) for i in range(ncells)]) + adata = anndata.AnnData(X=x, obs=obs, var=var) + return adata + + +def _load_script(dsg, rewrite: bool, match_to_reference): + dsg.load(allow_caching=True, load_raw=rewrite) + dsg.streamline_features(remove_gene_version=True, match_to_reference=match_to_reference) + dsg.streamline_metadata(schema="sfaira", clean_obs=True, clean_var=True, clean_uns=True, clean_obs_names=True) + return dsg + + +def prepare_dsg(rewrite: bool = False, load: bool = True) -> DatasetSuperGroupMock: + """ + Prepares data set super group of mock data and returns instance. + + Use this do testing involving a data set group. + """ + # Make sure cache exists: + if not os.path.exists(DIR_DATA_LOADERS_CACHE): + pathlib.Path(DIR_DATA_LOADERS_CACHE).mkdir(parents=True, exist_ok=True) + dsg = DatasetSuperGroupMock() + if load: + dsg = _load_script(dsg=dsg, rewrite=rewrite, match_to_reference=MATCH_TO_REFERENCE) + return dsg + + +def prepare_store(store_format: str, rewrite: bool = False, rewrite_store: bool = False) -> str: + """ + Prepares mock data store and returns path to store. + + Use this do testing involving a data set store. 
+ """ + dir_store_formatted = { + "dao": DIR_DATA_LOADERS_STORE_DAO, + "h5ad": DIR_DATA_LOADERS_STORE_H5AD, + }[store_format] + if not os.path.exists(dir_store_formatted): + pathlib.Path(dir_store_formatted).mkdir(parents=True, exist_ok=True) + dsg = prepare_dsg(rewrite=rewrite, load=False) + for k, ds in dsg.datasets.items(): + if store_format == "dao": + compression_kwargs = {"compressor": "default", "overwrite": True, "order": "C"} + else: + compression_kwargs = {} + if store_format == "dao": + anticipated_fn = os.path.join(dir_store_formatted, k) + elif store_format == "h5ad": + anticipated_fn = os.path.join(dir_store_formatted, k + ".h5ad") + else: + assert False + # Only rewrite if necessary + if rewrite_store or not os.path.exists(anticipated_fn): + ds = _load_script(dsg=ds, rewrite=rewrite, match_to_reference=MATCH_TO_REFERENCE) + ds.write_distributed_store(dir_cache=dir_store_formatted, store_format=store_format, dense=True, + chunks=128, compression_kwargs=compression_kwargs) + return dir_store_formatted diff --git a/sfaira/unit_tests/directories.py b/sfaira/unit_tests/directories.py new file mode 100644 index 000000000..f2c457c0f --- /dev/null +++ b/sfaira/unit_tests/directories.py @@ -0,0 +1,14 @@ +""" +All paths used throughout unit testing for temporary files. +""" + +import os + +DIR_TEMP = os.path.join(os.path.dirname(__file__), "temp") + +_DIR_DATA_LOADERS = os.path.join(DIR_TEMP, "loaders") +DIR_DATA_LOADERS_CACHE = os.path.join(_DIR_DATA_LOADERS, "cache") +DIR_DATA_LOADERS_STORE_DAO = os.path.join(_DIR_DATA_LOADERS, "store_dao") +DIR_DATA_LOADERS_STORE_H5AD = os.path.join(_DIR_DATA_LOADERS, "store_h5ad") +_DIR_DATA_DATABASES = os.path.join(DIR_TEMP, "databases") +DIR_DATA_DATABASES_CACHE = os.path.join(_DIR_DATA_DATABASES, "cache") diff --git a/sfaira/unit_tests/test_data/model_lookuptable.csv b/sfaira/unit_tests/test_data/model_lookuptable.csv deleted file mode 100644 index ca8a2eb5b..000000000 --- a/sfaira/unit_tests/test_data/model_lookuptable.csv +++ /dev/null @@ -1,3 +0,0 @@ -,model_id,url,md5 -1,embedding_mouse_lung_vae_theislab_0.1_0.1,some_url,some_md5 -2,celltype_mouse_lung_mlp_theislab_0.0.1_0.1,some_url,some_md5 diff --git a/sfaira/unit_tests/estimators/__init__.py b/sfaira/unit_tests/tests_by_submodule/__init__.py similarity index 100% rename from sfaira/unit_tests/estimators/__init__.py rename to sfaira/unit_tests/tests_by_submodule/__init__.py diff --git a/sfaira/unit_tests/trainer/__init__.py b/sfaira/unit_tests/tests_by_submodule/data/__init__.py similarity index 100% rename from sfaira/unit_tests/trainer/__init__.py rename to sfaira/unit_tests/tests_by_submodule/data/__init__.py diff --git a/sfaira/unit_tests/tests_by_submodule/data/test_clean_celltype_maps.py b/sfaira/unit_tests/tests_by_submodule/data/test_clean_celltype_maps.py new file mode 100644 index 000000000..caa5d3b0f --- /dev/null +++ b/sfaira/unit_tests/tests_by_submodule/data/test_clean_celltype_maps.py @@ -0,0 +1,8 @@ +from sfaira.data.dataloaders.loaders import DatasetSuperGroupLoaders + +# TODO export this into a maintenance module. +# def test_map_celltype_to_ontology(): +# # Paths do not matter here as data sets are not loaded for these operations. 
+# dsgl = DatasetSuperGroupLoaders(data_path="", meta_path="", cache_path="") +# for x in dsgl.dataset_groups: +# x.clean_ontology_class_map() diff --git a/sfaira/unit_tests/data/test_data_utils.py b/sfaira/unit_tests/tests_by_submodule/data/test_data_utils.py similarity index 100% rename from sfaira/unit_tests/data/test_data_utils.py rename to sfaira/unit_tests/tests_by_submodule/data/test_data_utils.py diff --git a/sfaira/unit_tests/tests_by_submodule/data/test_databases.py b/sfaira/unit_tests/tests_by_submodule/data/test_databases.py new file mode 100644 index 000000000..e4cbd32a1 --- /dev/null +++ b/sfaira/unit_tests/tests_by_submodule/data/test_databases.py @@ -0,0 +1,61 @@ +import os +import pytest +import shutil +from typing import List + +from sfaira.unit_tests.directories import DIR_DATA_DATABASES_CACHE +from sfaira.unit_tests.data_for_tests.databases.utils import prepare_dsg_database +from sfaira.unit_tests.data_for_tests.databases.consts import CELLXGENE_DATASET_ID +from sfaira.unit_tests.data_for_tests.loaders import ASSEMBLY_HUMAN, ASSEMBLY_MOUSE + + +# Execute this one first so that data sets are only downloaded once. Named test_a for this reason. +@pytest.mark.parametrize("database", ["cellxgene", ]) +@pytest.mark.parametrize("subset_args", [None, ["id", CELLXGENE_DATASET_ID], ]) +def test_a_dsgs_download(database: str, subset_args: List[str]): + """ + Tests if downloading of data base entries works. + + Warning, deletes entire database unit test cache. + """ + if os.path.exists(DIR_DATA_DATABASES_CACHE): + shutil.rmtree(DIR_DATA_DATABASES_CACHE) + dsg = prepare_dsg_database(database=database, download=False) + if subset_args is not None: + dsg.subset(key=subset_args[0], values=subset_args[1]) + dsg.download() + + +@pytest.mark.parametrize("database", ["cellxgene", ]) +@pytest.mark.parametrize("subset_args", [["id", CELLXGENE_DATASET_ID], ["organism", "human"], ]) +def test_dsgs_subset(database: str, subset_args: List[str]): + """ + Tests if subsetting results only in datasets of the desired characteristics. 
+ """ + dsg = prepare_dsg_database(database=database) + dsg.subset(key=subset_args[0], values=subset_args[1]) + + +@pytest.mark.parametrize("database", ["cellxgene", ]) +@pytest.mark.parametrize("subset_args", [["id", CELLXGENE_DATASET_ID], ]) +@pytest.mark.parametrize("match_to_reference", [None, {"human": ASSEMBLY_HUMAN, "mouse": ASSEMBLY_MOUSE}, ]) +def test_dsgs_adata(database: str, subset_args: List[str], match_to_reference: dict): + dsg = prepare_dsg_database(database=database) + dsg.subset(key=subset_args[0], values=subset_args[1]) + dsg.load() + if match_to_reference is not None: + dsg.streamline_features(remove_gene_version=True, match_to_reference=match_to_reference) + dsg.streamline_metadata(schema="sfaira", clean_obs=True, clean_var=True, clean_uns=True, clean_obs_names=True) + _ = dsg.adata + + +@pytest.mark.parametrize("database", ["cellxgene", ]) +@pytest.mark.parametrize("subset_args", [["id", CELLXGENE_DATASET_ID], ]) +@pytest.mark.parametrize("match_to_reference", [{"human": ASSEMBLY_HUMAN, "mouse": ASSEMBLY_MOUSE}, ]) +@pytest.mark.parametrize("subset_genes_to_type", [None, "protein_coding", ]) +def test_dsgs_streamline_features(database: str, subset_args: List[str], match_to_reference: dict, + subset_genes_to_type: str): + dsg = prepare_dsg_database(database=database) + dsg.subset(key=subset_args[0], values=subset_args[1]) + dsg.load() + dsg.streamline_features(match_to_reference=match_to_reference, subset_genes_to_type=subset_genes_to_type) diff --git a/sfaira/unit_tests/data/test_dataset.py b/sfaira/unit_tests/tests_by_submodule/data/test_dataset.py similarity index 63% rename from sfaira/unit_tests/data/test_dataset.py rename to sfaira/unit_tests/tests_by_submodule/data/test_dataset.py index eff963998..c7e063e0c 100644 --- a/sfaira/unit_tests/data/test_dataset.py +++ b/sfaira/unit_tests/tests_by_submodule/data/test_dataset.py @@ -5,14 +5,12 @@ from sfaira.data import DatasetSuperGroup from sfaira.data import Universe -MOUSE_GENOME_ANNOTATION = "Mus_musculus.GRCm38.102" - -dir_data = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data") -dir_meta = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data/meta") +from sfaira.unit_tests.data_for_tests.loaders import ASSEMBLY_MOUSE, prepare_dsg +from sfaira.unit_tests.directories import DIR_TEMP, DIR_DATA_LOADERS_CACHE def test_dsgs_instantiate(): - _ = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + _ = Universe(data_path=DIR_DATA_LOADERS_CACHE, meta_path=DIR_DATA_LOADERS_CACHE, cache_path=DIR_DATA_LOADERS_CACHE) @pytest.mark.parametrize("organ", ["intestine", "ileum"]) @@ -20,9 +18,10 @@ def test_dsgs_subset_dataset_wise(organ: str): """ Tests if subsetting results only in datasets of the desired characteristics. 
""" - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + ds = prepare_dsg(load=False) ds.subset(key="organism", values=["mouse"]) ds.subset(key="organ", values=[organ]) + ds.load() for x in ds.dataset_groups: for k, v in x.datasets.items(): assert v.organism == "mouse", v.organism @@ -30,23 +29,19 @@ def test_dsgs_subset_dataset_wise(organ: str): def test_dsgs_config_write_load(): - fn = dir_data + "/config.csv" - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + fn = os.path.join(DIR_TEMP, "config.csv") + ds = prepare_dsg(load=False) ds.subset(key="organism", values=["mouse"]) ds.subset(key="organ", values=["lung"]) + ds.load() ds.write_config(fn=fn) - ds2 = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + ds2 = prepare_dsg() ds2.load_config(fn=fn) assert np.all(ds.ids == ds2.ids) -""" -TODO tests from here on down require cached data for mouse lung -""" - - def test_dsgs_adata(): - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + ds = prepare_dsg(load=False) ds.subset(key="organism", values=["mouse"]) ds.subset(key="organ", values=["lung"]) ds.load() @@ -54,7 +49,7 @@ def test_dsgs_adata(): def test_dsgs_load(): - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + ds = prepare_dsg(load=False) ds.subset(key="organism", values=["mouse"]) ds.subset(key="organ", values=["lung"]) ds.load() @@ -66,7 +61,7 @@ def test_dsgs_subset_cell_wise(organ: str, celltype: str): """ Tests if subsetting results only in datasets of the desired characteristics. """ - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + ds = prepare_dsg(load=False) ds.subset(key="organism", values=["mouse"]) ds.subset(key="organ", values=[organ]) ds.load() @@ -87,21 +82,21 @@ def test_dsgs_subset_cell_wise(organ: str, celltype: str): @pytest.mark.parametrize("clean_obs_names", [True, False]) def test_dsgs_streamline_metadata(out_format: str, uns_to_obs: bool, clean_obs: bool, clean_var: bool, clean_uns: bool, clean_obs_names: bool): - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + ds = prepare_dsg(load=False) ds.subset(key="organism", values=["mouse"]) ds.subset(key="organ", values=["lung"]) ds.load() - ds.streamline_features(remove_gene_version=False, match_to_reference=MOUSE_GENOME_ANNOTATION, + ds.streamline_features(remove_gene_version=False, match_to_reference=ASSEMBLY_MOUSE, subset_genes_to_type=None) - ds.streamline_metadata(schema=out_format, uns_to_obs=uns_to_obs, clean_obs=clean_obs, clean_var=clean_var, + ds.streamline_metadata(schema=out_format, clean_obs=clean_obs, clean_var=clean_var, clean_uns=clean_uns, clean_obs_names=clean_obs_names) -@pytest.mark.parametrize("match_to_reference", ["Mus_musculus.GRCm38.102", {"mouse": MOUSE_GENOME_ANNOTATION}]) +@pytest.mark.parametrize("match_to_reference", ["Mus_musculus.GRCm38.102", {"mouse": ASSEMBLY_MOUSE}]) @pytest.mark.parametrize("remove_gene_version", [False, True]) @pytest.mark.parametrize("subset_genes_to_type", [None, "protein_coding"]) def test_dsgs_streamline_features(match_to_reference: str, remove_gene_version: bool, subset_genes_to_type: str): - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + ds = prepare_dsg(load=False) ds.subset(key="organism", values=["mouse"]) ds.subset(key="organ", values=["lung"]) ds.load() @@ -109,23 +104,8 @@ def test_dsgs_streamline_features(match_to_reference: str, remove_gene_version: 
subset_genes_to_type=subset_genes_to_type) -@pytest.mark.parametrize("store", ["h5ad"]) -@pytest.mark.parametrize("dense", [False]) -@pytest.mark.parametrize("clean_obs", [False, True]) -def test_dsg_write_store(store: str, dense: bool, clean_obs: bool): - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) - ds.subset(key="organism", values=["mouse"]) - ds.subset(key="organ", values=["lung"]) - ds.load() - ds.streamline_features(remove_gene_version=True, match_to_reference={"mouse": MOUSE_GENOME_ANNOTATION}, - subset_genes_to_type="protein_coding") - ds.streamline_metadata(schema="sfaira", uns_to_obs=False, clean_obs=clean_obs, clean_var=True, clean_uns=True, - clean_obs_names=True) - ds.write_distributed_store(dir_cache=os.path.join(dir_data, "store"), store_format=store, dense=dense) - - def test_dsg_load(): - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + ds = prepare_dsg(load=False) ds.subset(key="organism", values=["mouse"]) ds.subset(key="organ", values=["lung"]) ds = DatasetSuperGroup(dataset_groups=[ds]) @@ -133,7 +113,7 @@ def test_dsg_load(): def test_dsg_adata(): - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) + ds = prepare_dsg(load=False) ds.subset(key="organism", values=["mouse"]) ds.subset(key="organ", values=["lung"]) ds = DatasetSuperGroup(dataset_groups=[ds]) diff --git a/sfaira/unit_tests/data/test_store.py b/sfaira/unit_tests/tests_by_submodule/data/test_store.py similarity index 58% rename from sfaira/unit_tests/data/test_store.py rename to sfaira/unit_tests/tests_by_submodule/data/test_store.py index f74395f33..a095669ca 100644 --- a/sfaira/unit_tests/data/test_store.py +++ b/sfaira/unit_tests/tests_by_submodule/data/test_store.py @@ -1,27 +1,16 @@ import anndata import dask.array +import h5py import numpy as np import os import pytest import scipy.sparse -import time from typing import List from sfaira.data import load_store -from sfaira.versions.genomes import GenomeContainer -from sfaira.unit_tests.utils import cached_store_writing +from sfaira.versions.genomes.genomes import GenomeContainer - -MOUSE_GENOME_ANNOTATION = "Mus_musculus.GRCm38.102" -HUMAN_GENOME_ANNOTATION = "Homo_sapiens.GRCh38.102" - -dir_data = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data") -dir_meta = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data", "meta") - - -""" -Tests from here on down require cached data for mouse lung -""" +from sfaira.unit_tests.data_for_tests.loaders import ASSEMBLY_MOUSE, prepare_dsg, prepare_store @pytest.mark.parametrize("store_format", ["h5ad", "dao"]) @@ -29,48 +18,57 @@ def test_fatal(store_format: str): """ Test if basic methods abort. 
""" - store_path = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=MOUSE_GENOME_ANNOTATION, - store_format=store_format) + store_path = prepare_store(store_format=store_format) store = load_store(cache_path=store_path, store_format=store_format) store.subset(attr_key="organism", values=["mouse"]) - store.subset(attr_key="assay_sc", values=["10x sequencing"]) _ = store.n_obs _ = store.n_vars _ = store.var_names _ = store.shape _ = store.obs - _ = store.indices - _ = store.genome_container - _ = store.n_counts(idx=[1, 3]) + _ = store.stores["mouse"].indices + _ = store.genome_containers @pytest.mark.parametrize("store_format", ["h5ad", "dao"]) -@pytest.mark.parametrize("dataset", ["mouse_lung_2019_10xsequencing_pisco_022_10.1101/661728"]) -def test_data(store_format: str, dataset: str): +def test_data(store_format: str): """ Test if the data exposed by the store are the same as in the original Dataset instance after streamlining. """ - store_path, ds = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=MOUSE_GENOME_ANNOTATION, - store_format=store_format, return_ds=True) + # Run standard streamlining workflow on dsg and compare to object relayed via store. + # Prepare dsg. + dsg = prepare_dsg(rewrite=False, load=True) + # Prepare store. + # Rewriting store to avoid mismatch of randomly generated data in cache and store. + store_path = prepare_store(store_format=store_format, rewrite=False, rewrite_store=True) store = load_store(cache_path=store_path, store_format=store_format) - dataset_key_reduced = dataset.split("_10.")[0] - store.subset(attr_key="id", values=[dataset_key_reduced]) - adata_store = store.adata_by_key[dataset] - adata_ds = ds.datasets[dataset].adata - # Check .X - x_store = adata_store.X + store.subset(attr_key="doi_journal", values=["no_doi_mock1"]) + dataset_id = store.adata_by_key[list(store.indices.keys())[0]].uns["id"] + adata_store = store.adata_by_key[dataset_id] + x_store = store.data_by_key[dataset_id] + adata_ds = dsg.datasets[dataset_id].adata x_ds = adata_ds.X.todense() if isinstance(x_store, dask.array.Array): x_store = x_store.compute() + if isinstance(x_store, h5py.Dataset): + # Need to load sparse matrix into memory if it comes from a backed anndata object. + x_store = x_store[:, :] if isinstance(x_store, anndata._core.sparse_dataset.SparseDataset): # Need to load sparse matrix into memory if it comes from a backed anndata object. x_store = x_store[:, :] if isinstance(x_store, scipy.sparse.csr_matrix): x_store = x_store.todense() + if isinstance(x_ds, anndata._core.sparse_dataset.SparseDataset): + # Need to load sparse matrix into memory if it comes from a backed anndata object. + x_ds = x_ds[:, :] + if isinstance(x_ds, scipy.sparse.csr_matrix): + x_ds = x_ds.todense() # Check that non-zero elements are the same: - assert np.all(np.where(x_store > 0)[0] == np.where(x_ds > 0)[0]) - assert np.all(np.where(x_store > 0)[1] == np.where(x_ds > 0)[1]) - assert np.all(x_store - x_ds == 0.) + assert x_store.shape[0] == x_ds.shape[0] + assert x_store.shape[1] == x_ds.shape[1] + assert np.all(np.where(x_store > 0)[0] == np.where(x_ds > 0)[0]), (np.sum(x_store > 0), np.sum(x_ds > 0)) + assert np.all(np.where(x_store > 0)[1] == np.where(x_ds > 0)[1]), (np.sum(x_store > 0), np.sum(x_ds > 0)) + assert np.all(x_store - x_ds == 0.), (np.sum(x_store), np.sum(x_ds)) assert x_store.dtype == x_ds.dtype # Note: Do not run test on sum across entire object if dtype is float32 as this can result in test failures because # of float overflows. 
@@ -99,64 +97,56 @@ def test_config(store_format: str): """ Test that data set config files can be set, written and recovered. """ - store_path = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=MOUSE_GENOME_ANNOTATION, - store_format=store_format) + store_path = prepare_store(store_format=store_format) config_path = os.path.join(store_path, "config_lung") store = load_store(cache_path=store_path, store_format=store_format) store.subset(attr_key="organism", values=["mouse"]) - store.subset(attr_key="assay_sc", values=["10x sequencing"]) + store.subset(attr_key="assay_sc", values=["10x technology"]) store.write_config(fn=config_path) store2 = load_store(cache_path=store_path, store_format=store_format) store2.load_config(fn=config_path + ".pickle") assert np.all(store.indices.keys() == store2.indices.keys()) - assert np.all([np.all(store.indices[k] == store2.indices[k]) for k in store.indices.keys()]) + assert np.all([np.all(store.indices[k] == store2.indices[k]) + for k in store.indices.keys()]) @pytest.mark.parametrize("store_format", ["h5ad", "dao"]) -@pytest.mark.parametrize("idx", [np.array([2, 1020, 3, 20000, 20100]), - np.concatenate([np.arange(150, 200), np.array([1, 100, 2003, 33])])]) +@pytest.mark.parametrize("idx", [np.arange(1, 10), + np.concatenate([np.arange(30, 50), np.array([1, 4, 98])])]) @pytest.mark.parametrize("batch_size", [1, 7]) -@pytest.mark.parametrize("obs_keys", [[], ["cell_ontology_class"]]) -@pytest.mark.parametrize("gc", [(None, {}), (MOUSE_GENOME_ANNOTATION, {"biotype": "protein_coding"})]) +@pytest.mark.parametrize("obs_keys", [["cell_ontology_class"]]) @pytest.mark.parametrize("randomized_batch_access", [True, False]) -def test_generator_shapes(store_format: str, idx, batch_size: int, obs_keys: List[str], gc: tuple, - randomized_batch_access: bool): +def test_generator_shapes(store_format: str, idx, batch_size: int, obs_keys: List[str], randomized_batch_access: bool): """ Test generators queries do not throw errors and that output shapes are correct. 
""" - assembly, subset = gc - store_path = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=MOUSE_GENOME_ANNOTATION, - store_format=store_format) + store_path = prepare_store(store_format=store_format) store = load_store(cache_path=store_path, store_format=store_format) store.subset(attr_key="organism", values=["mouse"]) - if assembly is not None: - gc = GenomeContainer(assembly=assembly) - gc.subset(**subset) - store.genome_container = gc - g = store.generator( - idx=idx, + gc = GenomeContainer(assembly=ASSEMBLY_MOUSE) + gc.subset(**{"biotype": "protein_coding"}) + store.genome_containers = gc + g, _ = store.generator( + idx={"mouse": idx}, batch_size=batch_size, obs_keys=obs_keys, randomized_batch_access=randomized_batch_access, ) nobs = len(idx) if idx is not None else store.n_obs batch_sizes = [] - t0 = time.time() x = None obs = None + counter = 0 for i, z in enumerate(g()): + counter += 1 x_i, obs_i = z assert x_i.shape[0] == obs_i.shape[0] if i == 0: x = x_i obs = obs_i batch_sizes.append(x_i.shape[0]) - tdelta = time.time() - t0 - print(f"time for iterating over generator:" - f" {tdelta}s for {np.sum(batch_sizes)} cells in {len(batch_sizes)} batches," - f" {tdelta / len(batch_sizes)}s per batch.") - assert x.shape[1] == store.n_vars, (x.shape, store.n_vars) + assert counter > 0 + assert x.shape[1] == store.n_vars["mouse"], (x.shape, store.n_vars["mouse"]) assert obs.shape[1] == len(obs_keys), (obs.shape, obs_keys) assert np.sum(batch_sizes) == nobs, (batch_sizes, nobs) - if assembly is not None: - assert x.shape[1] == gc.n_var, (x.shape, gc.n_var) + assert x.shape[1] == gc.n_var, (x.shape, gc.n_var) diff --git a/sfaira/unit_tests/tests_by_submodule/estimators/__init__.py b/sfaira/unit_tests/tests_by_submodule/estimators/__init__.py new file mode 100644 index 000000000..680468bc4 --- /dev/null +++ b/sfaira/unit_tests/tests_by_submodule/estimators/__init__.py @@ -0,0 +1 @@ +from .test_estimator import TARGETS, TestHelperEstimatorBase diff --git a/sfaira/unit_tests/estimators/custom.obo b/sfaira/unit_tests/tests_by_submodule/estimators/custom.obo similarity index 100% rename from sfaira/unit_tests/estimators/custom.obo rename to sfaira/unit_tests/tests_by_submodule/estimators/custom.obo diff --git a/sfaira/unit_tests/estimators/test_estimator.py b/sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py similarity index 77% rename from sfaira/unit_tests/estimators/test_estimator.py rename to sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py index b404ef079..1aba8f02d 100644 --- a/sfaira/unit_tests/estimators/test_estimator.py +++ b/sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py @@ -4,20 +4,20 @@ import os import pandas as pd import pytest -import time from typing import Union -from sfaira.data import load_store, DistributedStoreBase +from sfaira.consts import AdataIdsSfaira, CACHE_DIR +from sfaira.data import DistributedStoreSingleFeatureSpace, load_store from sfaira.estimators import EstimatorKeras, EstimatorKerasCelltype, EstimatorKerasEmbedding -from sfaira.versions.genomes import CustomFeatureContainer +from sfaira.versions.genomes.genomes import CustomFeatureContainer from sfaira.versions.metadata import OntologyOboCustom from sfaira.versions.topologies import TopologyContainer -from sfaira.unit_tests.utils import cached_store_writing, simulate_anndata -dir_data = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data") -dir_meta = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data", 
"meta") -cache_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), - "cache", "genomes") +from sfaira.unit_tests.data_for_tests.loaders.consts import CELLTYPES +from sfaira.unit_tests.data_for_tests.loaders.utils import prepare_dsg, prepare_store +from sfaira.unit_tests.directories import DIR_TEMP + +CACHE_DIR_GENOMES = os.path.join(CACHE_DIR, "genomes") ASSEMBLY = { "mouse": "Mus_musculus.GRCm38.102", @@ -27,9 +27,10 @@ "mouse": ["ENSMUSG00000000003", "ENSMUSG00000000028"], "human": ["ENSG00000000003", "ENSG00000000005"], } -TARGETS = ["T cell", "CD4-positive helper T cell", "stromal cell", "UNKNOWN"] -TARGET_UNIVERSE = ["CD4-positive helper T cell", "stromal cell"] -ASSAYS = ["10x sequencing", "Smart-seq2"] +TARGETS = CELLTYPES +TARGET_UNIVERSE = CELLTYPES + +ASSAYS = ["10x technology", "Smart-seq2"] TOPOLOGY_EMBEDDING_MODEL = { @@ -64,9 +65,35 @@ } -class HelperEstimatorBase: +class TestHelperEstimatorBase: - data: Union[anndata.AnnData, DistributedStoreBase] + adata_ids: AdataIdsSfaira + data: Union[anndata.AnnData, DistributedStoreSingleFeatureSpace] + tc: TopologyContainer + + def load_adata(self, organism="human", organ=None): + dsg = prepare_dsg() + if organism is not None: + dsg.subset(key="organism", values=organism) + if organ is not None: + dsg.subset(key="organ", values=organ) + self.adata_ids = dsg.dataset_groups[0]._adata_ids + self.data = dsg.adata + + def load_store(self, organism="human", organ=None): + store_path = prepare_store(store_format="dao") + store = load_store(cache_path=store_path, store_format="dao") + if organism is not None: + store.subset(attr_key="organism", values=organism) + if organ is not None: + store.subset(attr_key="organ", values=organ) + self.adata_ids = store._adata_ids_sfaira + self.data = store.stores[organism] + + +class TestHelperEstimatorKeras(TestHelperEstimatorBase): + + data: Union[anndata.AnnData, DistributedStoreSingleFeatureSpace] estimator: Union[EstimatorKeras] model_type: str tc: TopologyContainer @@ -78,28 +105,6 @@ class HelperEstimatorBase: basic_estimator_test(). See _test_call() for an example. """ - def _simulate(self) -> anndata.AnnData: - """ - Simulate basic data example used for unit test. - - :return: Simulated data set. - """ - return simulate_anndata(n_obs=100, assays=ASSAYS, genes=self.tc.gc.ensembl, targets=TARGETS) - - def load_adata(self): - """ - Sets attribute .data with simulated data. 
- """ - self.data = self._simulate() - - def load_store(self, organism, organ): - store_path = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=ASSEMBLY[organism], - organism=organism, organ=organ) - store = load_store(cache_path=store_path) - store.subset(attr_key="organism", values=organism) - store.subset(attr_key="organ", values=organ) - self.data = store - @abc.abstractmethod def init_topology(self, model_type: str, feature_space: str, organism: str): pass @@ -130,13 +135,13 @@ def estimator_train(self, test_split, randomized_batch_access): def basic_estimator_test(self, test_split): pass - def load_estimator(self, model_type, data_type, feature_space, test_split, organism="mouse", organ="lung"): + def load_estimator(self, model_type, data_type, feature_space, test_split, organism="human"): self.init_topology(model_type=model_type, feature_space=feature_space, organism=organism) np.random.seed(1) if data_type == "adata": self.load_adata() else: - self.load_store(organism=organism, organ=organ) + self.load_store(organism=organism) self.init_estimator(test_split=test_split) def fatal_estimator_test(self, model_type, data_type, test_split=0.1, feature_space="small"): @@ -146,7 +151,7 @@ def fatal_estimator_test(self, model_type, data_type, test_split=0.1, feature_sp self.basic_estimator_test(test_split=test_split) -class HelperEstimatorKerasEmbedding(HelperEstimatorBase): +class HelperEstimatorKerasEmbedding(TestHelperEstimatorKeras): estimator: EstimatorKerasEmbedding model_type: str @@ -157,7 +162,7 @@ def init_topology(self, model_type: str, feature_space: str, organism: str): if feature_space == "full": # Read 500 genes (not full protein coding) to compromise between being able to distinguish observations # and reducing run time of unit tests. 
- tab = pd.read_csv(os.path.join(cache_dir, ASSEMBLY[organism] + ".csv")) + tab = pd.read_csv(os.path.join(CACHE_DIR_GENOMES, ASSEMBLY[organism] + ".csv")) genes_full = tab.loc[tab["gene_biotype"].values == "protein_coding", "gene_id"].values[:500].tolist() topology["input"]["genes"] = ["ensg", genes_full] else: @@ -174,7 +179,8 @@ def init_topology(self, model_type: str, feature_space: str, organism: str): def init_estimator(self, test_split): self.estimator = EstimatorKerasEmbedding( data=self.data, - model_dir=None, + model_dir=DIR_TEMP, + cache_path=DIR_TEMP, model_id="testid", model_topology=self.tc ) @@ -200,9 +206,10 @@ def basic_estimator_test(self, test_split=0.1): assert np.allclose(prediction_embed, new_prediction_embed, rtol=1e-6, atol=1e-6) -class HelperEstimatorKerasCelltype(HelperEstimatorBase): +class TestHelperEstimatorKerasCelltype(TestHelperEstimatorKeras): estimator: EstimatorKerasCelltype + nleaves: int model_type: str tc: TopologyContainer @@ -218,40 +225,34 @@ def init_topology(self, model_type: str, feature_space: str, organism: str): def init_estimator(self, test_split): tc = self.tc - if isinstance(self.data, DistributedStoreBase): + if isinstance(self.data, DistributedStoreSingleFeatureSpace): # Reset leaves below: tc.topology["output"]["targets"] = None self.estimator = EstimatorKerasCelltype( data=self.data, - model_dir=None, + model_dir=DIR_TEMP, + cache_path=DIR_TEMP, model_id="testid", model_topology=tc ) - if isinstance(self.data, DistributedStoreBase): - leaves = self.estimator.celltype_universe.onto_cl.get_effective_leaves( - x=[x for x in self.data.obs[self.data._adata_ids_sfaira.cellontology_class].values - if x != self.data._adata_ids_sfaira.unknown_celltype_identifier] - ) - self.nleaves = len(leaves) - self.estimator.celltype_universe.onto_cl.leaves = leaves - else: - self.nleaves = None + leaves = self.estimator.celltype_universe.onto_cl.get_effective_leaves( + x=[x for x in self.data.obs[self.adata_ids.cellontology_class].values + if x != self.adata_ids.unknown_celltype_identifier] + ) + self.nleaves = len(leaves) + self.estimator.celltype_universe.onto_cl.leaves = leaves self.estimator.init_model() self.estimator.split_train_val_test(test_split=test_split, val_split=0.1) def basic_estimator_test(self, test_split=0.1): _ = self.estimator.evaluate() prediction_output = self.estimator.predict() - if isinstance(self.estimator.data, anndata.AnnData): - assert prediction_output.shape[1] == len(TARGET_UNIVERSE), prediction_output.shape - else: - assert prediction_output.shape[1] == self.nleaves, prediction_output.shape + assert prediction_output.shape[1] == self.nleaves, prediction_output.shape weights = self.estimator.model.training_model.get_weights() self.estimator.save_weights_to_cache() self.estimator.load_weights_from_cache() new_prediction_output = self.estimator.predict() new_weights = self.estimator.model.training_model.get_weights() - print(self.estimator.model.training_model.summary()) for i in range(len(weights)): if not np.any(np.isnan(weights[i])): assert np.allclose(weights[i], new_weights[i], rtol=1e-6, atol=1e-6) @@ -259,7 +260,9 @@ def basic_estimator_test(self, test_split=0.1): assert np.allclose(prediction_output, new_prediction_output, rtol=1e-6, atol=1e-6) -class HelperEstimatorKerasCelltypeCustomObo(HelperEstimatorKerasCelltype): +class HelperEstimatorKerasCelltypeCustomObo(TestHelperEstimatorKerasCelltype): + + custom_types = ["MYONTO:01", "MYONTO:02", "MYONTO:03"] def init_obo_custom(self) -> OntologyOboCustom: return 
OntologyOboCustom(obo=os.path.join(os.path.dirname(__file__), "custom.obo")) @@ -271,16 +274,40 @@ def init_genome_custom(self, n_features) -> CustomFeatureContainer: "gene_biotype": ["embedding" for _ in range(n_features)], })) + def load_adata(self, organism="human", organ=None): + dsg = prepare_dsg(load=False) + if organism is not None: + dsg.subset(key="organism", values=organism) + if organ is not None: + dsg.subset(key="organ", values=organ) + self.adata_ids = dsg.dataset_groups[0]._adata_ids + # Use mock data loading to generate base line object: + dsg.load() + self.data = dsg.datasets[list(dsg.datasets.keys())[0]].adata + # - Subset to target feature space size: + self.data = self.data[:, :self.tc.gc.n_var].copy() + # - Add in custom cell types: + self.data.obs[self.adata_ids.cellontology_class] = [ + self.custom_types[np.random.randint(0, len(self.custom_types))] + for _ in range(self.data.n_obs) + ] + self.data.obs[self.adata_ids.cellontology_id] = self.data.obs[self.adata_ids.cellontology_class] + # - Add in custom features: + self.data.var_names = ["dim_" + str(i) for i in range(self.data.n_vars)] + self.data.var[self.adata_ids.gene_id_ensembl] = ["dim_" + str(i) for i in range(self.data.n_vars)] + self.data.var[self.adata_ids.gene_id_symbols] = ["dim_" + str(i) for i in range(self.data.n_vars)] + def init_topology_custom(self, model_type: str, n_features): topology = TOPOLOGY_CELLTYPE_MODEL.copy() topology["model_type"] = model_type topology["input"]["genome"] = "custom" topology["input"]["genes"] = ["biotype", "embedding"] topology["output"]["cl"] = "custom" - topology["output"]["targets"] = ["MYONTO:02", "MYONTO:03"] + topology["output"]["targets"] = self.custom_types[1:] if model_type == "mlp": topology["hyper_parameters"]["units"] = (2,) self.model_type = model_type + self.nleaves = len(topology["output"]["targets"]) gc = self.init_genome_custom(n_features=n_features) self.tc = TopologyContainer(topology=topology, topology_id="0.0.1", custom_genome_constainer=gc) @@ -288,11 +315,11 @@ def fatal_estimator_test_custom(self): self.init_topology_custom(model_type="mlp", n_features=50) obo = self.init_obo_custom() np.random.seed(1) - self.data = simulate_anndata(n_obs=100, genes=self.tc.gc.ensembl, - targets=["MYONTO:01", "MYONTO:02", "MYONTO:03"], obo=obo) + self.load_adata() self.estimator = EstimatorKerasCelltype( data=self.data, - model_dir=None, + model_dir=DIR_TEMP, + cache_path=DIR_TEMP, model_id="testid", model_topology=self.tc, celltype_ontology=obo, @@ -310,13 +337,13 @@ def test_for_fatal_linear(data_type): test_estim.fatal_estimator_test(model_type="linear", data_type=data_type) -@pytest.mark.parametrize("data_type", ["adata"]) +@pytest.mark.parametrize("data_type", ["store"]) def test_for_fatal_ae(data_type): test_estim = HelperEstimatorKerasEmbedding() test_estim.fatal_estimator_test(model_type="ae", data_type=data_type) -@pytest.mark.parametrize("data_type", ["adata"]) +@pytest.mark.parametrize("data_type", ["store"]) def test_for_fatal_vae(data_type): test_estim = HelperEstimatorKerasEmbedding() test_estim.fatal_estimator_test(model_type="vae", data_type=data_type) @@ -327,13 +354,13 @@ def test_for_fatal_vae(data_type): @pytest.mark.parametrize("data_type", ["adata", "store"]) def test_for_fatal_mlp(data_type): - test_estim = HelperEstimatorKerasCelltype() + test_estim = TestHelperEstimatorKerasCelltype() test_estim.fatal_estimator_test(model_type="mlp", data_type=data_type) -@pytest.mark.parametrize("data_type", ["adata"]) 
+@pytest.mark.parametrize("data_type", ["store"]) def test_for_fatal_marker(data_type): - test_estim = HelperEstimatorKerasCelltype() + test_estim = TestHelperEstimatorKerasCelltype() test_estim.fatal_estimator_test(model_type="marker", data_type=data_type) @@ -344,11 +371,9 @@ def test_for_fatal_mlp_custom(): # Test index sets -@pytest.mark.parametrize("organism", ["human"]) -@pytest.mark.parametrize("organ", ["lung"]) @pytest.mark.parametrize("batch_size", [1024, 2048, 4096]) @pytest.mark.parametrize("randomized_batch_access", [False, True]) -def test_dataset_size(organism: str, organ: str, batch_size: int, randomized_batch_access: bool): +def test_dataset_size(batch_size: int, randomized_batch_access: bool): """ Test that tf data set from estimator has same size as generator invoked directly from store based on number of observations in emitted batches. @@ -361,7 +386,7 @@ def test_dataset_size(organism: str, organ: str, batch_size: int, randomized_bat # Need full feature space here because observations are not necessarily different in small model testing feature # space with only two genes: test_estim.load_estimator(model_type="linear", data_type="store", feature_space="reduced", test_split=0.2, - organism=organism, organ=organ) + organism="human") idx_train = test_estim.estimator.idx_train shuffle_buffer_size = None if randomized_batch_access else 2 ds_train = test_estim.estimator._get_dataset(idx=idx_train, batch_size=batch_size, mode='eval', @@ -373,8 +398,8 @@ def test_dataset_size(organism: str, organ: str, batch_size: int, randomized_bat x_train_shape += x[0].shape[0] # Define raw store generator on train data to compare and check that it has the same size as tf generator exposed # by estimator: - g_train = test_estim.estimator.data.generator(idx=idx_train, batch_size=retrieval_batch_size, - randomized_batch_access=randomized_batch_access) + g_train, _ = test_estim.estimator.data.generator(idx=idx_train, batch_size=retrieval_batch_size, + randomized_batch_access=randomized_batch_access) x_train2_shape = 0 for x, _ in g_train(): x_train2_shape += x.shape[0] @@ -382,12 +407,10 @@ def test_dataset_size(organism: str, organ: str, batch_size: int, randomized_bat assert x_train_shape == len(idx_train) -@pytest.mark.parametrize("organism", ["mouse"]) -@pytest.mark.parametrize("organ", ["lung"]) @pytest.mark.parametrize("data_type", ["adata", "store"]) @pytest.mark.parametrize("randomized_batch_access", [False, True]) -@pytest.mark.parametrize("test_split", [0.3, {"assay_sc": "10x sequencing"}]) -def test_split_index_sets(organism: str, organ: str, data_type: str, randomized_batch_access: bool, test_split): +@pytest.mark.parametrize("test_split", [0.3, {"id": "human_lung_2021_10xtechnology_mock1_001_no_doi_mock1"}]) +def test_split_index_sets(data_type: str, randomized_batch_access: bool, test_split): """ Test that train, val, test split index sets are correct: @@ -400,13 +423,10 @@ def test_split_index_sets(organism: str, organ: str, data_type: str, randomized_ # Need full feature space here because observations are not necessarily different in small model testing feature # space with only two genes: test_estim.load_estimator(model_type="linear", data_type=data_type, test_split=test_split, feature_space="full", - organism=organism, organ=organ) + organism="human") idx_train = test_estim.estimator.idx_train idx_eval = test_estim.estimator.idx_eval idx_test = test_estim.estimator.idx_test - print(idx_train) - print(idx_eval) - print(idx_test) # 1) Assert that index assignment sets sum up 
to full data set: # Make sure that there are no repeated indices in each set. assert len(idx_train) == len(np.unique(idx_train)) @@ -414,17 +434,20 @@ def test_split_index_sets(organism: str, organ: str, data_type: str, randomized_ assert len(idx_test) == len(np.unique(idx_test)) assert len(idx_train) + len(idx_eval) + len(idx_test) == test_estim.data.n_obs, \ (len(idx_train), len(idx_eval), len(idx_test), test_estim.data.n_obs) - if isinstance(test_estim.data, DistributedStoreBase): + if isinstance(test_estim.data, DistributedStoreSingleFeatureSpace): assert np.sum([v.shape[0] for v in test_estim.data.adata_by_key.values()]) == test_estim.data.n_obs # 2) Assert that index assignments are exclusive to each split: assert len(set(idx_train).intersection(set(idx_eval))) == 0 assert len(set(idx_train).intersection(set(idx_test))) == 0 assert len(set(idx_test).intersection(set(idx_eval))) == 0 # 3) Check partition of index vectors over store data sets matches test split scenario: - if isinstance(test_estim.estimator.data, DistributedStoreBase): + if isinstance(test_estim.estimator.data, DistributedStoreSingleFeatureSpace): # Prepare data set-wise index vectors that are numbered in the same way as global split index vectors. - # See also EstimatorKeras.train and DistributedStoreBase.subset_cells_idx_global - idx_raw = test_estim.estimator.data.indices_global.values() + idx_raw = [] + counter = 0 + for v in test_estim.data.indices.values(): + idx_raw.append(np.arange(counter, counter + len(v))) + counter += len(v) if isinstance(test_split, float): # Make sure that indices from each split are in each data set: for i, z in enumerate([idx_train, idx_eval, idx_test]): @@ -454,57 +477,41 @@ def test_split_index_sets(organism: str, organ: str, data_type: str, randomized_ # Build numpy arrays of expression input data sets from tensorflow data sets directly from estimator. # These data sets are the most processed transformation of the data and stand directly in concat with the model. shuffle_buffer_size = None if randomized_batch_access else 2 - t0 = time.time() ds_train = test_estim.estimator._get_dataset(idx=idx_train, batch_size=1024, mode='eval', shuffle_buffer_size=shuffle_buffer_size, retrieval_batch_size=2048, randomized_batch_access=randomized_batch_access) - print(f"time for building training data set: {time.time() - t0}s") - t0 = time.time() ds_eval = test_estim.estimator._get_dataset(idx=idx_eval, batch_size=1024, mode='eval', shuffle_buffer_size=shuffle_buffer_size, retrieval_batch_size=2048, randomized_batch_access=randomized_batch_access) - print(f"time for building validation data set: {time.time() - t0}s") - t0 = time.time() ds_test = test_estim.estimator._get_dataset(idx=idx_test, batch_size=1024, mode='eval', shuffle_buffer_size=shuffle_buffer_size, retrieval_batch_size=2048, randomized_batch_access=randomized_batch_access) - print(f"time for building test data set: {time.time() - t0}s") # Create two copies of test data set to make sure that re-instantiation of a subset does not cause issues. 
ds_test2 = test_estim.estimator._get_dataset(idx=idx_test, batch_size=1024, mode='eval', shuffle_buffer_size=shuffle_buffer_size, retrieval_batch_size=2048, randomized_batch_access=randomized_batch_access) - print(f"time for building test data set: {time.time() - t0}s") x_train = [] x_eval = [] x_test = [] x_test2_shape = 0 - t0 = time.time() for x, _ in ds_train.as_numpy_iterator(): x_train.append(x[0]) x_train = np.concatenate(x_train, axis=0) - print(f"time for iterating over training data set: {time.time() - t0}s") - t0 = time.time() for x, _ in ds_eval.as_numpy_iterator(): x_eval.append(x[0]) x_eval = np.concatenate(x_eval, axis=0) - print(f"time for iterating over validation data set: {time.time() - t0}s") - t0 = time.time() for x, _ in ds_test.as_numpy_iterator(): x_test.append(x[0]) x_test = np.concatenate(x_test, axis=0) - print(f"time for iterating over test data set: {time.time() - t0}s") # Assert that duplicate of test data has the same shape: for x, _ in ds_test2: x_test2_shape += x[0].shape[0] assert x_test2_shape == x_test.shape[0] # Validate size of recovered numpy data sets: - print(test_estim.data.n_obs) - print(f"shapes expected {(len(idx_train), len(idx_eval), len(idx_test))}") - print(f"shapes received {(x_train.shape[0], x_eval.shape[0], x_test.shape[0])}") assert x_train.shape[0] + x_eval.shape[0] + x_test.shape[0] == test_estim.data.n_obs assert len(idx_train) + len(idx_eval) + len(idx_test) == test_estim.data.n_obs assert x_train.shape[0] == len(idx_train) diff --git a/sfaira/unit_tests/ui/__init__.py b/sfaira/unit_tests/tests_by_submodule/trainer/__init__.py similarity index 100% rename from sfaira/unit_tests/ui/__init__.py rename to sfaira/unit_tests/tests_by_submodule/trainer/__init__.py diff --git a/sfaira/unit_tests/trainer/test_trainer.py b/sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py similarity index 56% rename from sfaira/unit_tests/trainer/test_trainer.py rename to sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py index e153fce2f..3fdf83365 100644 --- a/sfaira/unit_tests/trainer/test_trainer.py +++ b/sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py @@ -4,18 +4,15 @@ from typing import Union from sfaira.data import load_store -from sfaira.ui import ModelZoo from sfaira.train import TrainModelCelltype, TrainModelEmbedding -from sfaira.unit_tests.utils import cached_store_writing, simulate_anndata - -dir_data = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data") -dir_meta = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data/meta") +from sfaira.ui import ModelZoo +from sfaira.versions.metadata import CelltypeUniverse, OntologyCl, OntologyUberon -ASSEMBLY = "Mus_musculus.GRCm38.102" -TARGETS = ["T cell", "stromal cell"] +from sfaira.unit_tests.tests_by_submodule.estimators import TestHelperEstimatorBase, TARGETS +from sfaira.unit_tests import DIR_TEMP -class HelperTrainerBase: +class HelperTrainerBase(TestHelperEstimatorBase): data: Union[anndata.AnnData, load_store] trainer: Union[TrainModelCelltype, TrainModelEmbedding] @@ -24,25 +21,6 @@ def __init__(self, zoo: ModelZoo): self.model_id = zoo.model_id self.tc = zoo.topology_container - def _simulate(self) -> anndata.AnnData: - """ - Simulate basic data example used for unit test. - - :return: Simulated data set. - """ - return simulate_anndata(n_obs=100, genes=self.tc.gc.ensembl, targets=TARGETS) - - def load_adata(self): - """ - Sets attribute .data with simulated data. 
- """ - self.data = self._simulate() - - def load_store(self): - store_path = cached_store_writing(dir_data=dir_data, dir_meta=dir_meta, assembly=ASSEMBLY, organism="mouse") - store = load_store(cache_path=store_path) - self.data = store - def load_data(self, data_type): np.random.seed(1) if data_type == "adata": @@ -50,19 +28,24 @@ def load_data(self, data_type): else: self.load_store() - def test_init(self, cls): + def test_init(self, cls, **kwargs): + if not os.path.exists(DIR_TEMP): + os.mkdir(DIR_TEMP) self.load_data(data_type="adata") self.trainer = cls( data=self.data, - model_path=dir_meta, + model_path=os.path.join(DIR_TEMP, "model"), + **kwargs ) self.trainer.zoo.model_id = self.model_id self.trainer.init_estim(override_hyperpar={}) def test_save(self): + if not os.path.exists(DIR_TEMP): + os.mkdir(DIR_TEMP) self.trainer.estimator.train(epochs=1, max_steps_per_epoch=1, test_split=0.1, validation_split=0.1, optimizer="adam", lr=0.005) - self.trainer.save(fn=os.path.join(dir_data, "trainer_test"), model=True, specific=True) + self.trainer.save(fn=os.path.join(DIR_TEMP, "trainer"), model=True, specific=True) def test_save_embedding(): @@ -75,9 +58,16 @@ def test_save_embedding(): def test_save_celltypes(): + # Create temporary cell type universe to give to trainer. + tmp_fn = os.path.join(DIR_TEMP, "universe_temp.csv") + cl = OntologyCl(branch="v2021-02-01") + uberon = OntologyUberon() + cu = CelltypeUniverse(cl=cl, uberon=uberon) + cu.write_target_universe(fn=tmp_fn, x=TARGETS) + del cu model_id = "celltype_human-lung-mlp-0.0.1-0.1_mylab" zoo = ModelZoo() zoo.model_id = model_id test_trainer = HelperTrainerBase(zoo=zoo) - test_trainer.test_init(cls=TrainModelCelltype) + test_trainer.test_init(cls=TrainModelCelltype, fn_target_universe=tmp_fn) test_trainer.test_save() diff --git a/sfaira/unit_tests/versions/__init__.py b/sfaira/unit_tests/tests_by_submodule/ui/__init__.py similarity index 100% rename from sfaira/unit_tests/versions/__init__.py rename to sfaira/unit_tests/tests_by_submodule/ui/__init__.py diff --git a/sfaira/unit_tests/ui/test_userinterface.py b/sfaira/unit_tests/tests_by_submodule/ui/test_userinterface.py similarity index 89% rename from sfaira/unit_tests/ui/test_userinterface.py rename to sfaira/unit_tests/tests_by_submodule/ui/test_userinterface.py index 69c5f1f02..37fd6ba49 100644 --- a/sfaira/unit_tests/ui/test_userinterface.py +++ b/sfaira/unit_tests/tests_by_submodule/ui/test_userinterface.py @@ -3,6 +3,7 @@ from typing import Union from sfaira.ui import UserInterface +from sfaira.unit_tests import DIR_TEMP class TestUi: @@ -33,5 +34,5 @@ def _test_basic(self): :return: """ - temp_fn = os.path.join(str(os.path.dirname(os.path.abspath(__file__))), '../test_data') + temp_fn = os.path.join(DIR_TEMP, "test_data") self.ui = UserInterface(custom_repo=temp_fn, sfaira_repo=False) diff --git a/sfaira/unit_tests/ui/test_zoo.py b/sfaira/unit_tests/tests_by_submodule/ui/test_zoo.py similarity index 82% rename from sfaira/unit_tests/ui/test_zoo.py rename to sfaira/unit_tests/tests_by_submodule/ui/test_zoo.py index d531bd14e..988374128 100644 --- a/sfaira/unit_tests/ui/test_zoo.py +++ b/sfaira/unit_tests/tests_by_submodule/ui/test_zoo.py @@ -1,9 +1,5 @@ -import os from sfaira.ui import ModelZoo -dir_data = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data") -dir_meta = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_data/meta") - def test_for_fatal_embedding(): model_id = "embedding_human-lung-linear-0.1-0.1_mylab" diff --git 
a/sfaira/unit_tests/tests_by_submodule/versions/__init__.py b/sfaira/unit_tests/tests_by_submodule/versions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sfaira/unit_tests/tests_by_submodule/versions/test_genomes.py b/sfaira/unit_tests/tests_by_submodule/versions/test_genomes.py new file mode 100644 index 000000000..3403fa45f --- /dev/null +++ b/sfaira/unit_tests/tests_by_submodule/versions/test_genomes.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest +from typing import Tuple, Union + +from sfaira.versions.genomes import GenomeContainer, translate_id_to_symbols, translate_symbols_to_id + +ASSEMBLY = "Mus_musculus.GRCm38.102" + +""" +GenomeContainer. +""" + + +@pytest.mark.parametrize("assembly", [ASSEMBLY]) +def test_gc_init(assembly: Union[str]): + """ + Tests different modes of initialisation for fatal errors. + """ + gc = GenomeContainer(assembly=assembly) + assert gc.organism == "mus_musculus" + + +@pytest.mark.parametrize("subset", [ + ({"biotype": "protein_coding"}, 21936), + ({"biotype": "lincRNA"}, 5629), + ({"biotype": "protein_coding,lincRNA"}, 21936 + 5629), + ({"symbols": "Gnai3,Pbsn,Cdc45"}, 3), + ({"ensg": "ENSMUSG00000000003,ENSMUSG00000000028"}, 2) +]) +def test_gc_subsetting(subset: Tuple[dict, int]): + """ + Tests if genome container is subsetted correctly. + """ + gc = GenomeContainer(assembly="Mus_musculus.GRCm38.102") + gc.subset(**subset[0]) + assert gc.n_var == subset[1] + assert len(gc.ensembl) == subset[1] + assert len(gc.symbols) == subset[1] + assert len(gc.biotype) == subset[1] + if list(subset[0].keys())[0] == "protein_coding": + assert np.all(gc.biotype == "protein_coding") + + +""" +Utils. +""" + + +@pytest.mark.parametrize("genes", [ + ("Adora3", "ENSMUSG00000000562"), # single string + (["Adora3", "Timp1"], ["ENSMUSG00000000562", "ENSMUSG00000001131"]), # list of strings + (["ADORA3", "timp1"], ["EnsmusG00000000562", "ENSMUSG00000001131"]), # list of strings with weird capitalization +]) +def test_translate_id_to_symbols(genes): + """ + Tests translate_id_to_symbols and translate_symbols_to_id for translation errors. 
+ """ + x, y = genes + y_hat = translate_symbols_to_id(x=x, assembly="Mus_musculus.GRCm38.102") + # Correct target spelling of y: + y = [z.upper() for z in y] if isinstance(y, list) else y.upper() + assert np.all(y_hat == y) + y, x = genes + y_hat = translate_id_to_symbols(x=x, assembly="Mus_musculus.GRCm38.102") + # Correct target spelling of y: + y = [z[0].upper() + z[1:].lower() for z in y] if isinstance(y, list) else y[0].upper() + y[1:].lower() + assert np.all(y_hat == y) diff --git a/sfaira/unit_tests/versions/test_ontologies.py b/sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py similarity index 79% rename from sfaira/unit_tests/versions/test_ontologies.py rename to sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py index 301bc66d6..32b81fb78 100644 --- a/sfaira/unit_tests/versions/test_ontologies.py +++ b/sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py @@ -59,18 +59,14 @@ def test_cl_set_leaves(): assert set(leaves) == set(targets), leaves assert len(oc.node_ids) == 22 assert np.all([x in oc.convert_to_name(oc.node_ids) for x in targets]), oc.convert_to_name(oc.node_ids) - leaf_map_1 = oc.convert_to_name(oc.map_to_leaves(node="lymphocyte", include_self=True)) - leaf_map_2 = oc.convert_to_name(oc.map_to_leaves(node="lymphocyte", include_self=False)) - leaf_map_3 = oc.map_to_leaves(node="lymphocyte", include_self=True, return_type="idx") - leaf_map_4 = oc.convert_to_name(oc.map_to_leaves(node="T-helper 1 cell", include_self=True)) - leaf_map_5 = oc.map_to_leaves(node="T-helper 1 cell", include_self=False) - leaf_map_6 = oc.map_to_leaves(node="T-helper 1 cell", include_self=True, return_type="idx") + leaf_map_1 = oc.convert_to_name(oc.map_to_leaves(node="lymphocyte")) + leaf_map_2 = oc.map_to_leaves(node="lymphocyte", return_type="idx") + leaf_map_3 = oc.convert_to_name(oc.map_to_leaves(node="T-helper 1 cell")) + leaf_map_4 = oc.map_to_leaves(node="T-helper 1 cell", return_type="idx") assert set(leaf_map_1) == {"T-helper 1 cell", "T-helper 17 cell"} - assert set(leaf_map_2) == {"T-helper 1 cell", "T-helper 17 cell"} - assert np.all(leaf_map_3 == np.sort([oc.convert_to_name(oc.leaves).index(x) for x in list(leaf_map_1)])) - assert set(leaf_map_4) == {"T-helper 1 cell"} - assert leaf_map_5 == [] - assert np.all(leaf_map_6 == np.sort([oc.convert_to_name(oc.leaves).index(x) for x in list(leaf_map_4)])) + assert np.all(leaf_map_2 == np.sort([oc.convert_to_name(oc.leaves).index(x) for x in list(leaf_map_1)])) + assert set(leaf_map_3) == {"T-helper 1 cell"} + assert np.all(leaf_map_4 == np.sort([oc.convert_to_name(oc.leaves).index(x) for x in list(leaf_map_3)])) """ @@ -125,7 +121,7 @@ def test_sclc_nodes(): """ sclc = OntologySinglecellLibraryConstruction() assert "10x technology" in sclc.node_names - assert "10x 5' v3" in sclc.node_names + assert "10x 3' v3" in sclc.node_names assert "Smart-like" in sclc.node_names assert "Smart-seq2" in sclc.node_names assert "sci-plex" in sclc.node_names @@ -137,11 +133,10 @@ def test_sclc_is_a(): Tests if is-a relationships work correctly. 
""" sclc = OntologySinglecellLibraryConstruction() - assert sclc.is_a(query="10x v1", reference="10x technology") - assert sclc.is_a(query="10x 5' v3", reference="10x technology") - assert sclc.is_a(query="10x 5' v3", reference="10x v3") - assert not sclc.is_a(query="10x technology", reference="10x v1") - assert sclc.is_a(query="10x 5' v3", reference="single cell library construction") + assert sclc.is_a(query="10x 3' v3", reference="10x technology") + assert sclc.is_a(query="10x 3' v3", reference="10x 3' transcription profiling") + assert not sclc.is_a(query="10x technology", reference="10x 3' transcription profiling") + assert sclc.is_a(query="10x 3' v3", reference="single cell library construction") assert sclc.is_a(query="sci-plex", reference="single cell library construction") assert not sclc.is_a(query="sci-plex", reference="10x technology") diff --git a/sfaira/unit_tests/versions/test_universe.py b/sfaira/unit_tests/tests_by_submodule/versions/test_universe.py similarity index 55% rename from sfaira/unit_tests/versions/test_universe.py rename to sfaira/unit_tests/tests_by_submodule/versions/test_universe.py index 560eec945..23222b613 100644 --- a/sfaira/unit_tests/versions/test_universe.py +++ b/sfaira/unit_tests/tests_by_submodule/versions/test_universe.py @@ -1,6 +1,7 @@ import os from sfaira.versions.metadata import CelltypeUniverse, OntologyCl, OntologyUberon +from sfaira.unit_tests import DIR_TEMP """ CelltypeUniverse @@ -8,13 +9,16 @@ def test_universe_io(): - tmp_fn = "./universe_tempp.csv" + if not os.path.exists(DIR_TEMP): + os.mkdir(DIR_TEMP) + tmp_fn = os.path.join(DIR_TEMP, "universe_temp.csv") targets = ["stromal cell", "lymphocyte", "T-helper 1 cell", "T-helper 17 cell"] + leaves_target = ["stromal cell", "T-helper 1 cell", "T-helper 17 cell"] cl = OntologyCl(branch="v2021-02-01") uberon = OntologyUberon() cu = CelltypeUniverse(cl=cl, uberon=uberon) cu.write_target_universe(fn=tmp_fn, x=targets) cu.load_target_universe(fn=tmp_fn) os.remove(tmp_fn) - leaves = cu.leaves - assert set(leaves) == set(targets), (leaves, targets) + leaves = cu.onto_cl.convert_to_name(cu.onto_cl.leaves) + assert set(leaves) == set(leaves_target), (leaves, leaves_target) diff --git a/sfaira/unit_tests/utils.py b/sfaira/unit_tests/utils.py deleted file mode 100644 index 55667411e..000000000 --- a/sfaira/unit_tests/utils.py +++ /dev/null @@ -1,96 +0,0 @@ -import anndata -import numpy as np -import os -from typing import Tuple, Union - -from sfaira.consts import AdataIdsSfaira, OCS -from sfaira.data import Universe -from sfaira.versions.metadata import OntologyOboCustom - - -def simulate_anndata(genes, n_obs, targets=None, assays=None, obo: Union[None, OntologyOboCustom] = None) -> \ - anndata.AnnData: - """ - Simulate basic data example. - - :return: AnnData instance. 
- """ - adata_ids_sfaira = AdataIdsSfaira() - data = anndata.AnnData( - np.random.randint(low=0, high=100, size=(n_obs, len(genes))).astype(np.float32) - ) - if assays is not None: - data.obs[adata_ids_sfaira.assay_sc] = [ - assays[np.random.randint(0, len(assays))] - for _ in range(n_obs) - ] - if targets is not None: - data.obs[adata_ids_sfaira.cellontology_class] = [ - targets[np.random.randint(0, len(targets))] - for _ in range(n_obs) - ] - if obo is None: - data.obs[adata_ids_sfaira.cellontology_id] = [ - OCS.cellontology_class.convert_to_id(x) - if x not in [adata_ids_sfaira.unknown_celltype_identifier, - adata_ids_sfaira.not_a_cell_celltype_identifier] - else x - for x in data.obs[adata_ids_sfaira.cellontology_class].values - ] - else: - data.obs[adata_ids_sfaira.cellontology_id] = [ - obo.convert_to_id(x) - if x not in [adata_ids_sfaira.unknown_celltype_identifier, - adata_ids_sfaira.not_a_cell_celltype_identifier] - else x - for x in data.obs[adata_ids_sfaira.cellontology_class].values - ] - data.var[adata_ids_sfaira.gene_id_ensembl] = genes - return data - - -def cached_store_writing(dir_data, dir_meta, assembly, organism: str = "mouse", organ: str = "lung", - store_format: str = "h5ad", return_ds: bool = False) -> Union[str, Tuple[str, Universe]]: - """ - Writes a store if it does not already exist. - - :return: Path to store. - """ - adata_ids_sfaira = AdataIdsSfaira() - store_path = os.path.join(dir_data, "store") - if not os.path.exists(store_path): - os.mkdir(store_path) - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) - ds.subset(key=adata_ids_sfaira.organism, values=[organism]) - ds.subset(key=adata_ids_sfaira.organ, values=[organ]) - # Only load files that are not already in cache. - anticipated_files = np.unique([ - v.doi_journal[0] if isinstance(v.doi_journal, list) else v.doi_journal for k, v in ds.datasets.items() - if (not os.path.exists(os.path.join(store_path, v.doi_cleaned_id + "." 
+ store_format)) and - store_format == "h5ad") or - (not os.path.exists(os.path.join(store_path, v.doi_cleaned_id)) and store_format == "dao") - ]).tolist() - ds.subset(key=adata_ids_sfaira.doi, values=anticipated_files) - ds.load(allow_caching=True) - ds.streamline_features(remove_gene_version=True, match_to_reference={organism: assembly}, - subset_genes_to_type="protein_coding") - ds.streamline_metadata(schema="sfaira", uns_to_obs=True, clean_obs=True, clean_var=True, clean_uns=True, - clean_obs_names=True) - if store_format == "zarr": - compression_kwargs = {"compressor": "default", "overwrite": True, "order": "C"} - else: - compression_kwargs = {} - ds.write_distributed_store(dir_cache=store_path, store_format=store_format, dense=store_format == "dao", - chunks=128, compression_kwargs=compression_kwargs) - if return_ds: - ds = Universe(data_path=dir_data, meta_path=dir_meta, cache_path=dir_data) - ds.subset(key=adata_ids_sfaira.organism, values=[organism]) - ds.subset(key=adata_ids_sfaira.organ, values=[organ]) - ds.load(allow_caching=True) - ds.streamline_features(remove_gene_version=True, match_to_reference={organism: assembly}, - subset_genes_to_type="protein_coding") - ds.streamline_metadata(schema="sfaira", uns_to_obs=True, clean_obs=True, clean_var=True, clean_uns=True, - clean_obs_names=True) - return store_path, ds - else: - return store_path diff --git a/sfaira/unit_tests/versions/test_genomes.py b/sfaira/unit_tests/versions/test_genomes.py deleted file mode 100644 index 512ca2417..000000000 --- a/sfaira/unit_tests/versions/test_genomes.py +++ /dev/null @@ -1,40 +0,0 @@ -import numpy as np -import pytest -from typing import Tuple, Union - -from sfaira.versions.genomes import GenomeContainer - -""" -GenomeContainer -""" - - -@pytest.mark.parametrize("organism", ["mouse"]) -@pytest.mark.parametrize("assembly", [None, "Mus_musculus.GRCm38.102"]) -def test_gc_init(organism: Union[str, None], assembly: Union[str, None]): - """ - Tests different modes of initialisation for fatal errors. - """ - gc = GenomeContainer(organism=organism, assembly=assembly) - assert gc.organism == "mus_musculus" - - -@pytest.mark.parametrize("subset", [ - ({"biotype": "protein_coding"}, 21936), - ({"biotype": "lincRNA"}, 5629), - ({"biotype": "protein_coding,lincRNA"}, 21936 + 5629), - ({"symbols": "Gnai3,Pbsn,Cdc45"}, 3), - ({"ensg": "ENSMUSG00000000003,ENSMUSG00000000028"}, 2) -]) -def test_gc_subsetting(subset: Tuple[dict, int]): - """ - Tests if genome container is subsetted correctly. - """ - gc = GenomeContainer(organism=None, assembly="Mus_musculus.GRCm38.102") - gc.subset(**subset[0]) - assert gc.n_var == subset[1] - assert len(gc.ensembl) == subset[1] - assert len(gc.symbols) == subset[1] - assert len(gc.biotype) == subset[1] - if list(subset[0].keys())[0] == "protein_coding": - assert np.all(gc.biotype == "protein_coding") diff --git a/sfaira/unit_tests/versions/test_zoo.py b/sfaira/unit_tests/versions/test_zoo.py deleted file mode 100644 index 7ab6d821d..000000000 --- a/sfaira/unit_tests/versions/test_zoo.py +++ /dev/null @@ -1,91 +0,0 @@ -import abc -import numpy as np -import os -import pandas as pd -from typing import Union -import unittest - -from sfaira.ui.model_zoo import ModelZoo, ModelZooCelltype, ModelZooEmbedding - - -class _TestZoo: - zoo: Union[ModelZoo] - data: np.ndarray - - """ - Contains functions _test* to test individual functions and attributes of estimator class. 
- - TODO for everybody working on this, add one _test* function in here and add it into - basic_estimator_test(). See _test_kipoi_call() for an example. - """ - - @abc.abstractmethod - def init_zoo(self): - """ - Initialise target zoo as .zoo attribute. - - :return: - """ - pass - - def simulate(self): - """ - Simulate basic data example used for unit test. - - Sets attribute .data with simulated data. - - :return: - """ - pass - - def _test_basic(self, id: str): - """ - Test all relevant model methods. - - - :return: - """ - np.random.seed(1) - self.simulate() - self.init_zoo() - self.zoo_manual.set_model_id(id) - - -class TestZooKerasEmbedding(unittest.TestCase, _TestZoo): - - def init_zoo(self): - package_dir = str(os.path.dirname(os.path.abspath(__file__))) - lookup_table = pd.read_csv( - os.path.join(package_dir, '../test_data', 'model_lookuptable.csv'), - header=0, index_col=0 - ) - self.zoo = ModelZoo(model_lookuptable=lookup_table) - self.zoo_manual = ModelZoo(model_lookuptable=None) - - def test_basic(self): - self._test_basic(id="embedding_mouse_lung_vae_theislab_0.1_0.1") - self.zoo.set_latest('mouse', 'lung', 'vae', 'theislab', '0.1') - assert self.zoo.model_id == "embedding_mouse_lung_vae_theislab_0.1_0.1" - assert self.zoo.model_id == self.zoo_manual.model_id - - -class TestZooKerasCelltype(unittest.TestCase, _TestZoo): - - def init_zoo(self): - package_dir = str(os.path.dirname(os.path.abspath(__file__))) - lookup_table = pd.read_csv( - os.path.join(package_dir, '../test_data', 'model_lookuptable.csv'), - header=0, index_col=0 - ) - self.zoo = ModelZoo(model_lookuptable=lookup_table) - self.zoo_manual = ModelZoo(model_lookuptable=None) - - def test_basic(self): - self._test_basic(id="celltype_mouse_lung_mlp_theislab_0.0.1_0.1") - self.zoo.set_latest('mouse', 'lung', 'mlp', 'theislab', '0.0.1') - assert self.zoo.model_id == "celltype_mouse_lung_mlp_theislab_0.0.1_0.1" - assert self.zoo.model_id == self.zoo_manual.model_id - - -if __name__ == '__main__': - unittest.main() diff --git a/sfaira/versions/genomes/__init__.py b/sfaira/versions/genomes/__init__.py new file mode 100644 index 000000000..afc17716d --- /dev/null +++ b/sfaira/versions/genomes/__init__.py @@ -0,0 +1,2 @@ +from .genomes import GenomeContainer, GtfInterface +from .utils import translate_id_to_symbols, translate_symbols_to_id diff --git a/sfaira/versions/genomes.py b/sfaira/versions/genomes/genomes.py similarity index 67% rename from sfaira/versions/genomes.py rename to sfaira/versions/genomes/genomes.py index 3e43babd1..0dedd4c01 100644 --- a/sfaira/versions/genomes.py +++ b/sfaira/versions/genomes/genomes.py @@ -5,12 +5,14 @@ import gzip import numpy as np import os -from typing import List, Union +from typing import Iterable, List, Union import pandas import pathlib import urllib.error import urllib.request +from sfaira.consts.directories import CACHE_DIR_GENOMES + KEY_SYMBOL = "gene_name" KEY_ID = "gene_id" KEY_TYPE = "gene_biotype" @@ -32,10 +34,9 @@ def cache_dir(self): """ The cache dir is in a cache directory in the sfaira installation that is excempt from git versioning. 
""" - cache_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "cache", "genomes") - cache_dir_path = pathlib.Path(cache_dir) + cache_dir_path = pathlib.Path(CACHE_DIR_GENOMES) cache_dir_path.mkdir(parents=True, exist_ok=True) - return cache_dir + return CACHE_DIR_GENOMES @property def cache_fn(self): @@ -58,11 +59,11 @@ def download_gtf_ensembl(self): Download .gtf file from ensembl FTP server and turn into reduced, gene-centric cache .csv. """ temp_file = os.path.join(self.cache_dir, self.assembly + ".gtf.gz") - print(f"downloading {self.url_ensembl_ftp} into a temporary file {temp_file}") try: _ = urllib.request.urlretrieve(url=self.url_ensembl_ftp, filename=temp_file) except urllib.error.URLError as e: - raise ValueError(f"Could not download gtf from {self.url_ensembl_ftp} with urllib.error.URLError: {e}") + raise ValueError(f"Could not download gtf from {self.url_ensembl_ftp} with urllib.error.URLError: {e}, " + f"check if assembly name '{self.assembly}' corresponds to an actual assembly.") with gzip.open(temp_file) as f: tab = pandas.read_csv(f, sep="\t", comment="#", header=None) os.remove(temp_file) # Delete temporary file .gtf.gz. @@ -88,6 +89,12 @@ def cache(self) -> pandas.DataFrame: class GenomeContainer: + """ + Container class for a genome annotation for a specific release. + + This class can be used to translate between symbols and ENSEMBL IDs for a specific assembly, to store specific gene + subsets of an assembly, and to subselect genes by biotypes in an assembly. + """ genome_tab: pandas.DataFrame assembly: str @@ -96,6 +103,16 @@ def __init__( self, assembly: str = None, ): + """ + Are you not sure which assembly to use? + + - You could use the newest one for example, check the ENSEMBL site regularly for updates: + http://ftp.ensembl.org/pub/ + - You could use one used by a specific aligner, the assemblies used by 10x cellranger are described here + for example: https://support.10xgenomics.com/single-cell-gene-expression/software/release-notes/build + + :param assembly: The full name of the genome assembly, e.g. Homo_sapiens.GRCh38.102. + """ if not isinstance(assembly, str): raise ValueError(f"supplied assembly {assembly} was not a string") self.assembly = assembly @@ -167,26 +184,43 @@ def subset( self.genome_tab = self.genome_tab.loc[subset, :].copy() @property - def symbols(self): + def symbols(self) -> List[str]: + """ + List of symbols of genes in genome container. + """ return self.genome_tab[KEY_SYMBOL].values.tolist() @property - def ensembl(self): + def ensembl(self) -> List[str]: + """ + List of ENSEMBL IDs of genes in genome container. + """ return self.genome_tab[KEY_ID].values.tolist() @property - def biotype(self): + def biotype(self) -> List[str]: + """ + List of biotypes of genes in genome container. 
+ """ return self.genome_tab[KEY_TYPE].values.tolist() - def __validate_ensembl(self, x: List[str]): - not_found = [y for y in x if y not in self.ensembl] + def __validate_ensembl(self, x: List[str], enforce_captitalization: bool = True): + if enforce_captitalization: + not_found = [y for y in x if y not in self.ensembl] + else: + ensembl_upper = [y.upper() for y in self.ensembl] + not_found = [y for y in x if y.upper() not in ensembl_upper] if len(not_found) > 0: - raise ValueError(f"Could not find ensembl: {not_found}") - - def __validate_symbols(self, x: List[str]): - not_found = [y for y in x if y not in self.symbols] + raise ValueError(f"Could not find ENSEMBL ID: {not_found}") + + def __validate_symbols(self, x: List[str], enforce_captitalization: bool = True): + if enforce_captitalization: + not_found = [y for y in x if y not in self.symbols] + else: + symbols_upper = [y.upper() for y in self.symbols] + not_found = [y for y in x if y.upper() not in symbols_upper] if len(not_found) > 0: - raise ValueError(f"Could not find names: {not_found}") + raise ValueError(f"Could not find symbol: {not_found}") def __validate_types(self, x: List[str]): not_found = [y for y in x if y not in self.biotype] @@ -195,16 +229,57 @@ def __validate_types(self, x: List[str]): @property def n_var(self) -> int: + """ + Number of genes in genome container. + """ return self.genome_tab.shape[0] @property - def names_to_id_dict(self): + def symbol_to_id_dict(self): + """ + Dictionary-formatted map of gene symbols to ENSEMBL IDs. + """ return dict(zip(self.genome_tab[KEY_SYMBOL].values.tolist(), self.genome_tab[KEY_ID].values.tolist())) @property - def id_to_names_dict(self): + def id_to_symbols_dict(self): + """ + Dictionary-formatted map of ENSEMBL IDs to gene symbols. + """ return dict(zip(self.genome_tab[KEY_ID].values.tolist(), self.genome_tab[KEY_SYMBOL].values.tolist())) + def translate_symbols_to_id(self, x: Union[str, Iterable[str]]) -> Union[str, List[str]]: + """ + Translate gene symbols to ENSEMBL IDs. + + :param x: Symbol(s) to translate. + :return: ENSEMBL IDs + """ + if isinstance(x, str): + x = [x] + self.__validate_symbols(x=x, enforce_captitalization=False) + map_dict = dict([(k.upper(), v) for k, v in self.symbol_to_id_dict.items()]) + y = [map_dict[xx.upper()] for xx in x] + if len(y) == 1: + y = y[0] + return y + + def translate_id_to_symbols(self, x: Union[str, Iterable[str]]) -> Union[str, List[str]]: + """ + Translate ENSEMBL IDs to gene symbols. + + :param x: ENSEMBL ID(s) to translate. + :return: Gene symbols. + """ + if isinstance(x, str): + x = [x] + self.__validate_ensembl(x=x, enforce_captitalization=False) + map_dict = dict([(k.upper(), v) for k, v in self.id_to_symbols_dict.items()]) + y = [map_dict[xx.upper()] for xx in x] + if len(y) == 1: + y = y[0] + return y + @property def strippednames_to_id_dict(self): return dict(zip([i.split(".")[0] for i in self.genome_tab[KEY_SYMBOL]], diff --git a/sfaira/versions/genomes/utils.py b/sfaira/versions/genomes/utils.py new file mode 100644 index 000000000..70c65e500 --- /dev/null +++ b/sfaira/versions/genomes/utils.py @@ -0,0 +1,43 @@ +from typing import Iterable, List, Union + +from sfaira.versions.genomes import GenomeContainer + + +def translate_symbols_to_id(x: Union[str, Iterable[str]], assembly: str) -> Union[str, List[str]]: + """ + Translate gene symbols to ENSEMBL IDs. + + Input captitalization is ignored but the output capitalisation matches the ENSEMBL .gtf files. + + Are you not sure which assembly to use? 
+ + - You could use the newest one for example, check the ENSEMBL site regularly for updates: + http://ftp.ensembl.org/pub/ + - You could use one used by a specific aligner, the assemblies used by 10x cellranger are described here + for example: https://support.10xgenomics.com/single-cell-gene-expression/software/release-notes/build + + :param x: Symbol(s) to translate. + :param assembly: The full name of the genome assembly, e.g. "Homo_sapiens.GRCh38.102". + :return: ENSEMBL IDs + """ + return GenomeContainer(assembly=assembly).translate_symbols_to_id(x=x) + + +def translate_id_to_symbols(x: Union[str, Iterable[str]], assembly: str) -> Union[str, List[str]]: + """ + Translate ENSEMBL IDs to gene symbols. + + Input captitalization is ignored but the output capitalisation matches the ENSEMBL .gtf files. + + Are you not sure which assembly to use? + + - You could use the newest one for example, check the ENSEMBL site regularly for updates: + http://ftp.ensembl.org/pub/ + - You could use one used by a specific aligner, the assemblies used by 10x cellranger are described here + for example: https://support.10xgenomics.com/single-cell-gene-expression/software/release-notes/build + + :param x: ENSEMBL ID(s) to translate. + :param assembly: The full name of the genome assembly, e.g. "Homo_sapiens.GRCh38.102". + :return: Gene symbols. + """ + return GenomeContainer(assembly=assembly).translate_id_to_symbols(x=x) diff --git a/sfaira/versions/metadata/base.py b/sfaira/versions/metadata/base.py index 73d6fc841..ae9e420c9 100644 --- a/sfaira/versions/metadata/base.py +++ b/sfaira/versions/metadata/base.py @@ -7,7 +7,7 @@ import requests from typing import Dict, List, Tuple, Union -FILE_PATH = __file__ +from sfaira.consts.directories import CACHE_DIR_ONTOLOGIES """ Ontology managament classes. @@ -26,20 +26,15 @@ """ -def get_base_ontology_cache() -> str: - folder = FILE_PATH.split(os.sep)[:-4] - folder.insert(1, os.sep) - return os.path.join(*folder, "cache", "ontologies") - - -def cached_load_obo(url, ontology_cache_dir, ontology_cache_fn): +def cached_load_obo(url, ontology_cache_dir, ontology_cache_fn, recache: bool = False): if os.name == "nt": # if running on windows, do not download obo file, but rather pass url directly to obonet + # TODO add caching option. obofile = url else: - ontology_cache_dir = os.path.join(get_base_ontology_cache(), ontology_cache_dir) + ontology_cache_dir = os.path.join(CACHE_DIR_ONTOLOGIES, ontology_cache_dir) obofile = os.path.join(ontology_cache_dir, ontology_cache_fn) # Download if necessary: - if not os.path.isfile(obofile): + if not os.path.isfile(obofile) or recache: os.makedirs(name=ontology_cache_dir, exist_ok=True) def download_obo(): @@ -53,17 +48,18 @@ def download_obo(): return obofile -def cached_load_ebi(ontology_cache_dir, ontology_cache_fn) -> (networkx.MultiDiGraph, os.PathLike): +def cached_load_ebi(ontology_cache_dir, ontology_cache_fn, recache: bool = False) -> (networkx.MultiDiGraph, os.PathLike): """ Load pickled graph object if available. 
:param ontology_cache_dir: :param ontology_cache_fn: + :param recache: :return: """ - ontology_cache_dir = os.path.join(get_base_ontology_cache(), ontology_cache_dir) + ontology_cache_dir = os.path.join(CACHE_DIR_ONTOLOGIES, ontology_cache_dir) picklefile = os.path.join(ontology_cache_dir, ontology_cache_fn) - if os.path.isfile(picklefile): + if os.path.isfile(picklefile) and not recache: with open(picklefile, 'rb') as f: graph = pickle.load(f) else: @@ -316,18 +312,18 @@ def map_to_leaves( """ Map a given node to leave nodes. - :param node: + :param node: Node(s) to map as symbol(s) or ID(s). :param return_type: "ids": IDs of mapped leave nodes "idx": indicies in leave note list of mapped leave nodes - :param include_self: whether to include node itself + :param include_self: DEPRECEATED. :return: """ node = self.convert_to_id(node) ancestors = self.get_ancestors(node) - if include_self: - ancestors = ancestors + [node] + # Add node itself to list of ancestors. + ancestors = ancestors + [node] if len(ancestors) > 0: ancestors = self.convert_to_id(ancestors) leaves = self.convert_to_id(self.leaves) @@ -376,6 +372,7 @@ def __init__( additional_terms: dict, additional_edges: List[Tuple[str, str]], ontology_cache_fn: str, + recache: bool, **kwargs ): def get_url_self(iri): @@ -442,7 +439,8 @@ def recursive_search(iri): edges_new.extend([(k_self, k_c) for k_c in direct_children]) return nodes_new, edges_new - graph, picklefile = cached_load_ebi(ontology_cache_dir=ontology, ontology_cache_fn=ontology_cache_fn) + graph, picklefile = cached_load_ebi(ontology_cache_dir=ontology, ontology_cache_fn=ontology_cache_fn, + recache=recache) if graph is None: self.graph = networkx.MultiDiGraph() nodes, edges = recursive_search(iri=root_term) @@ -588,12 +586,14 @@ class OntologyUberon(OntologyExtendedObo): def __init__( self, + recache: bool = False, **kwargs ): obofile = cached_load_obo( url="http://purl.obolibrary.org/obo/uberon.obo", ontology_cache_dir="uberon", ontology_cache_fn="uberon.obo", + recache=recache, ) super().__init__(obo=obofile) @@ -776,6 +776,7 @@ def __init__( self, branch: str, use_developmental_relationships: bool = False, + recache: bool = False, **kwargs ): """ @@ -790,6 +791,7 @@ def __init__( url=f"https://raw.github.com/obophenotype/cell-ontology/{branch}/cl.obo", ontology_cache_dir="cl", ontology_cache_fn=f"{branch}_cl.obo", + recache=recache, ) super().__init__(obo=obofile) @@ -857,12 +859,14 @@ class OntologyHsapdv(OntologyExtendedObo): def __init__( self, + recache: bool = False, **kwargs ): obofile = cached_load_obo( url="http://purl.obolibrary.org/obo/hsapdv.obo", ontology_cache_dir="hsapdv", ontology_cache_fn="hsapdv.obo", + recache=recache, ) super().__init__(obo=obofile) @@ -883,12 +887,14 @@ class OntologyMmusdv(OntologyExtendedObo): def __init__( self, + recache: bool = False, **kwargs ): obofile = cached_load_obo( url="http://purl.obolibrary.org/obo/mmusdv.obo", ontology_cache_dir="mmusdv", ontology_cache_fn="mmusdv.obo", + recache=recache, ) super().__init__(obo=obofile) @@ -909,12 +915,14 @@ class OntologyMondo(OntologyExtendedObo): def __init__( self, + recache: bool = False, **kwargs ): obofile = cached_load_obo( url="http://purl.obolibrary.org/obo/mondo.obo", ontology_cache_dir="mondo", ontology_cache_fn="mondo.obo", + recache=recache, ) super().__init__(obo=obofile) @@ -944,12 +952,14 @@ class OntologyCellosaurus(OntologyExtendedObo): def __init__( self, + recache: bool = False, **kwargs ): obofile = cached_load_obo( 
url="https://ftp.expasy.org/databases/cellosaurus/cellosaurus.obo", ontology_cache_dir="cellosaurus", ontology_cache_fn="cellosaurus.obo", + recache=recache, ) super().__init__(obo=obofile) @@ -969,7 +979,7 @@ def synonym_node_properties(self) -> List[str]: class OntologySinglecellLibraryConstruction(OntologyEbi): - def __init__(self): + def __init__(self, recache: bool = False): super().__init__( ontology="efo", root_term="EFO_0010183", @@ -981,5 +991,6 @@ def __init__(self): ("EFO:0010183", "sci-plex"), ("EFO:0010183", "sci-RNA-seq"), ], - ontology_cache_fn="efo.pickle" + ontology_cache_fn="efo.pickle", + recache=recache, ) diff --git a/sfaira/versions/topologies/class_interface.py b/sfaira/versions/topologies/class_interface.py index 6c3a26489..b7e1025f5 100644 --- a/sfaira/versions/topologies/class_interface.py +++ b/sfaira/versions/topologies/class_interface.py @@ -1,6 +1,6 @@ from typing import Union -from sfaira.versions.genomes import GenomeContainer +from sfaira.versions.genomes.genomes import GenomeContainer class TopologyContainer: From ffd8925abfdb3a9a534af317e65294753905fe7e Mon Sep 17 00:00:00 2001 From: "David S. Fischer" Date: Tue, 7 Sep 2021 10:47:52 +0200 Subject: [PATCH 09/15] store improvements (#346) * improvments to store API * added retrieval index sort to dask store * fixed bug in single store generator if index input was None * added sliced X and adata object emission to single store * moved memory footprint into store base class * fixed h5ad store indexing * restructured meta data streamlining code (#347) - includes bug fix that lead to missing meta data import from cellxgene structured data sets - simplified meta data streamlining code and enhanced code readability - depreceated distinction between cell type and cell type original in data set definition in favor of single attribute - allowed all ontology constrained meta data items to be supplied in any format (original + mapl, symbol, or id) via the `*_obs_col` attribute of the loader - removed resetting of _obs_col attributes in streamlining in favor of adataids controlled obs col names that extend to IDs and original labels - updated cell type entry in all data loaders * added attribute check for dictionary formatted attributes from YAML * added processing of obs columns in cellxgene import * extended error reporting in data loader discovery * fixed value protection in meta data streamlining * fixed cellxgene obs adapter * added additional mock data set with little meta data annotation * refactored cellxgene streamlining and added HANCESTRO support via EBI * fixed handling of missing ethnicity ontology for mouse * fixed EBI EFO backend * ontology unit tests now check that ontologies can be downloaded * added new generator interface, restructured batch index design interface and fixed adata uns merge in DatasetGroup (#351) - Iterators for tf dataset and similar are now emitted as an instance of a class that has an property that emit the iterator. This class keeps a pointer to the data set that is iterated over in its attributes. Thus, if this instance stays in the namespace in which tensorflow uses the iterator, it can be restarted without creating a new pointer. This had previously delayed training because tensorflow restarted the validation data set for each epoch, thus creating a new dask data set in each epoch at relatively high cost. - There is now only one iterator end point for stores (before there was base and balanced). 
The different index shuffling / sampling schedules are now refactored into functions and can be chosen based on string names. This makes creation and addition of new index schedules ("batch designs") easier. - Direct conversion of adata objects in memory to a store is now supported via a new multi store class. - Estimators do not have any more adata processing code but still acceppt adata, next to store instances. The adata are directly converted to a adata store instance though. All previous code related to adata processing is depreceated in the estimators. - The interface of store to estimators in the estimator is heavily simplified through the new generator interface of the store. The generator instances are placed in the train name space for efficiency but not in testing and evaluation namespaces, in which only a data set single pass is required. * Added new batch index design code - Batch schedules are now classes rather than functions. - Introduced epoch-wise reshuffling of indices in batch schedule: The reshuffling is achieved by transferring the schedule from a one-time function evaluation in the generator constructor to a evaluation of a schedule instance property that shuffles at the beginning of the iterator * Fixed balanced batch schedule. * Added merging of shared uns fields in DatasetGroup so that uns streamlining is maintained across merge of adatas. * passed empty store index validation * passed zero length index processing in batch schedule * allowed re-indexing of generator and batch schedule --- sfaira/consts/__init__.py | 3 +- sfaira/consts/adata_fields.py | 130 ++- sfaira/consts/ontologies.py | 48 +- sfaira/data/dataloaders/base/dataset.py | 872 +++++++--------- sfaira/data/dataloaders/base/dataset_group.py | 113 ++- .../databases/cellxgene/cellxgene_loader.py | 140 ++- .../dataloaders/export_adaptors/__init__.py | 1 + .../dataloaders/export_adaptors/cellxgene.py | 89 ++ ...letoflangerhans_2017_smartseq2_enge_001.py | 2 +- .../mouse_x_2018_microwellseq_han_x.py | 2 +- ...fcolon_2019_10xsequencing_kinchen_001.yaml | 2 +- ...pithelium_2019_10xsequencing_smilie_001.py | 2 +- ...man_ileum_2019_10xsequencing_martin_001.py | 2 +- ...stategland_2018_10xsequencing_henry_001.py | 2 +- .../human_pancreas_2016_indrop_baron_001.py | 2 +- ...pancreas_2016_smartseq2_segerstolpe_001.py | 2 +- ..._pancreas_2019_10xsequencing_thompson_x.py | 2 +- ...uman_lung_2020_10xsequencing_miller_001.py | 2 +- ...an_brain_2019_dropseq_polioudakis_001.yaml | 2 +- .../human_brain_2017_droncseq_habib_001.py | 2 +- ...human_testis_2018_10xsequencing_guo_001.py | 2 +- ...liver_2018_10xsequencing_macparland_001.py | 2 +- .../human_kidney_2019_droncseq_lake_001.py | 2 +- .../human_x_2019_10xsequencing_szabo_001.py | 2 +- ...man_retina_2019_10xsequencing_menon_001.py | 2 +- .../human_placenta_2018_x_ventotormo_001.py | 2 +- .../human_liver_2019_celseq2_aizarani_001.py | 2 +- ...ver_2019_10xsequencing_ramachandran_001.py | 2 +- ...an_liver_2019_10xsequencing_popescu_001.py | 2 +- ...rain_2019_10x3v2sequencing_kanton_001.yaml | 2 +- .../human_x_2020_microwellseq_han_x.py | 2 +- .../human_lung_2020_x_travaglini_001.yaml | 2 +- ...uman_colon_2020_10xsequencing_james_001.py | 2 +- .../human_lung_2019_dropseq_braga_001.py | 2 +- .../human_x_2019_10xsequencing_braga_x.py | 2 +- .../mouse_x_2019_10xsequencing_hove_001.py | 2 +- ...man_retina_2019_10xsequencing_voigt_001.py | 2 +- .../human_x_2019_10xsequencing_wang_001.py | 2 +- ...an_lung_2020_10xsequencing_lukassen_001.py | 2 +- 
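[Editor's illustration, not part of the patch] A minimal sketch of the batch-schedule pattern described in the commit message above: the schedule is a class, and the index permutation is produced by a property that is read when an iterator starts, so every epoch sees a re-shuffled batch order instead of the one-time shuffle previously done in the generator constructor. All names below are invented for illustration; the real implementation lives in sfaira/data/store/batch_schedule.py and sfaira/data/store/generators.py listed in this diff stat.

import numpy as np


class ShuffledBatchSchedule:
    """Illustrative only: yields a fresh batch order every time the design is read."""

    def __init__(self, idx, batch_size: int):
        self.idx = np.asarray(idx)
        self.batch_size = batch_size

    @property
    def design(self):
        # Evaluated at the start of each iteration: a new permutation rather than a
        # permutation fixed once at construction time.
        idx = np.random.permutation(self.idx)
        return [idx[i:i + self.batch_size] for i in range(0, len(idx), self.batch_size)]


def iterator(schedule: ShuffledBatchSchedule):
    # Restarting this generator, e.g. for a new epoch, re-reads the property and
    # therefore emits a re-shuffled sequence of batches.
    for batch_idx in schedule.design:
        yield batch_idx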
.../human_blood_2020_10x_hao_001.yaml | 2 +- .../d10_1101_661728/mouse_x_2019_x_pisco_x.py | 2 +- ...nchyma_2020_10xsequencing_habermann_001.py | 2 +- ...n_kidney_2019_10xsequencing_stewart_001.py | 2 +- ...uman_thymus_2020_10xsequencing_park_001.py | 2 +- .../human_x_2020_scirnaseq_cao_001.yaml | 2 +- ...uman_x_2019_10xsequencing_madissoon_001.py | 2 +- ..._retina_2019_10xsequencing_lukowski_001.py | 2 +- sfaira/data/store/__init__.py | 3 +- sfaira/data/store/base.py | 107 ++ sfaira/data/store/batch_schedule.py | 127 +++ sfaira/data/store/generators.py | 376 +++++++ sfaira/data/store/io_dao.py | 5 +- sfaira/data/store/load_store.py | 34 + sfaira/data/store/multi_store.py | 156 +-- sfaira/data/store/single_store.py | 485 ++++----- sfaira/data/utils.py | 2 +- .../utils_scripts/create_target_universes.py | 2 +- .../data/utils_scripts/streamline_selected.py | 5 +- sfaira/data/utils_scripts/test_store.py | 7 +- sfaira/estimators/keras.py | 960 ++++++------------ sfaira/models/celltype/__init__.py | 1 + sfaira/models/celltype/base.py | 22 + sfaira/models/celltype/marker.py | 4 +- sfaira/models/celltype/mlp.py | 4 +- sfaira/models/embedding/__init__.py | 1 + sfaira/models/embedding/ae.py | 12 +- sfaira/models/embedding/base.py | 32 + sfaira/models/embedding/linear.py | 10 +- sfaira/models/embedding/vae.py | 4 +- sfaira/models/embedding/vaeiaf.py | 7 +- sfaira/models/embedding/vaevamp.py | 4 +- sfaira/train/summaries.py | 6 +- sfaira/train/train_model.py | 9 +- ...man_lung_2021_10xtechnology_mock1_001.yaml | 4 +- ...pancreas_2021_10xtechnology_mock2_001.yaml | 6 +- ...man_lung_2021_10xtechnology_mock3_001.yaml | 4 +- .../loaders/loaders/dno_doi_mock4/__init__.py | 1 + ...human_lung_2021_10xtechnology_mock4_001.py | 12 + ...man_lung_2021_10xtechnology_mock4_001.yaml | 52 + .../loaders/loaders/super_group.py | 2 +- .../data_for_tests/loaders/utils.py | 16 +- sfaira/unit_tests/directories.py | 10 + .../data/databases/__init__.py | 0 .../data/databases/test_database_intput.py | 65 ++ .../data/databases/test_databases_basic.py | 35 + .../data/dataset/__init__.py | 0 .../data/{ => dataset}/test_dataset.py | 37 +- .../dataset/test_meta_data_streamlining.py | 55 + .../tests_by_submodule/data/test_databases.py | 61 -- .../tests_by_submodule/data/test_store.py | 75 +- .../tests_by_submodule/estimators/__init__.py | 2 +- .../estimators/test_estimator.py | 133 ++- .../trainer/test_trainer.py | 26 +- .../ui/test_userinterface.py | 12 +- .../versions/test_ontologies.py | 29 +- sfaira/versions/genomes/genomes.py | 6 + sfaira/versions/metadata/__init__.py | 4 +- sfaira/versions/metadata/base.py | 51 +- 98 files changed, 2616 insertions(+), 1941 deletions(-) create mode 100644 sfaira/data/dataloaders/export_adaptors/__init__.py create mode 100644 sfaira/data/dataloaders/export_adaptors/cellxgene.py create mode 100644 sfaira/data/store/base.py create mode 100644 sfaira/data/store/batch_schedule.py create mode 100644 sfaira/data/store/generators.py create mode 100644 sfaira/data/store/load_store.py create mode 100644 sfaira/models/celltype/base.py create mode 100644 sfaira/models/embedding/base.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/__init__.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/human_lung_2021_10xtechnology_mock4_001.py create mode 100644 sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/human_lung_2021_10xtechnology_mock4_001.yaml create mode 100644 
sfaira/unit_tests/tests_by_submodule/data/databases/__init__.py create mode 100644 sfaira/unit_tests/tests_by_submodule/data/databases/test_database_intput.py create mode 100644 sfaira/unit_tests/tests_by_submodule/data/databases/test_databases_basic.py create mode 100644 sfaira/unit_tests/tests_by_submodule/data/dataset/__init__.py rename sfaira/unit_tests/tests_by_submodule/data/{ => dataset}/test_dataset.py (74%) create mode 100644 sfaira/unit_tests/tests_by_submodule/data/dataset/test_meta_data_streamlining.py delete mode 100644 sfaira/unit_tests/tests_by_submodule/data/test_databases.py diff --git a/sfaira/consts/__init__.py b/sfaira/consts/__init__.py index aaaa7d3e1..cf29b89eb 100644 --- a/sfaira/consts/__init__.py +++ b/sfaira/consts/__init__.py @@ -1,4 +1,5 @@ -from sfaira.consts.adata_fields import AdataIds, AdataIdsSfaira, AdataIdsCellxgene +from sfaira.consts.adata_fields import AdataIds, AdataIdsSfaira, AdataIdsCellxgene, AdataIdsCellxgeneGeneral, \ + AdataIdsCellxgeneHuman_v1_1_0, AdataIdsCellxgeneMouse_v1_1_0 from sfaira.consts.directories import CACHE_DIR from sfaira.consts.meta_data_files import META_DATA_FIELDS from sfaira.consts.ontologies import OntologyContainerSfaira diff --git a/sfaira/consts/adata_fields.py b/sfaira/consts/adata_fields.py index e38bcdb58..19f928e02 100644 --- a/sfaira/consts/adata_fields.py +++ b/sfaira/consts/adata_fields.py @@ -12,9 +12,7 @@ class AdataIds: annotated: str assay_sc: str author: str - cell_types_original: str - cellontology_class: str - cellontology_id: str + cell_type: str development_stage: str disease: str doi_journal: str @@ -24,25 +22,32 @@ class AdataIds: dataset: str dataset_group: str ethnicity: str - gene_id_ensembl: str - gene_id_index: str - gene_id_symbols: str + feature_id: str + feature_index: str + feature_symbol: str + feature_biotype: str id: str individual: str ncells: str normalization: str organ: str organism: str + primary_data: str sample_source: str sex: str state_exact: str tech_sample: str year: str + onto_id_suffix: str + onto_original_suffix: str + load_raw: str mapped_features: str remove_gene_version: str + ontology_constrained: List[str] + obs_keys: List[str] var_keys: List[str] uns_keys: List[str] @@ -52,7 +57,7 @@ class AdataIds: classmap_target_key: str classmap_target_id_key: str - unknown_celltype_identifier: Union[str, None] + invalid_metadata_identifier: Union[str, None] not_a_cell_celltype_identifier: Union[str, None] unknown_metadata_identifier: Union[str, None] @@ -77,6 +82,9 @@ class AdataIdsSfaira(AdataIds): cell_line: str def __init__(self): + self.onto_id_suffix = "_ontology_term_id" + self.onto_original_suffix = "_original" + self.annotated = "annotated" self.assay_sc = "assay_sc" self.assay_differentiation = "assay_differentiation" @@ -84,9 +92,7 @@ def __init__(self): self.author = "author" self.bio_sample = "bio_sample" self.cell_line = "cell_line" - self.cell_types_original = "cell_types_original" - self.cellontology_class = "cell_ontology_class" - self.cellontology_id = "cell_ontology_id" + self.cell_type = "cell_type" self.default_embedding = "default_embedding" self.disease = "disease" self.doi_journal = "doi_journal" @@ -95,9 +101,10 @@ def __init__(self): self.dataset_group = "dataset_group" self.download_url_data = "download_url_data" self.download_url_meta = "download_url_meta" - self.gene_id_ensembl = "ensembl" - self.gene_id_index = self.gene_id_ensembl - self.gene_id_symbols = "names" + self.feature_id = "ensembl" + self.feature_index = self.feature_id + self.feature_symbol 
= "gene_symbol" + self.feature_biotype = "feature_biotype" self.id = "id" self.individual = "individual" self.ncells = "ncells" @@ -123,22 +130,28 @@ def __init__(self): self.classmap_target_key = "target" self.classmap_target_id_key = "target_id" - self.unknown_celltype_identifier = "UNKNOWN" + self.invalid_metadata_identifier = "na" self.not_a_cell_celltype_identifier = "NOT_A_CELL" self.unknown_metadata_identifier = "unknown" - self.unknown_metadata_ontology_id_identifier = "unknown" self.batch_keys = [self.bio_sample, self.individual, self.tech_sample] + self.ontology_constrained = [ + "assay_sc", + "cell_line", + "cell_type", + "development_stage", + "disease", + "ethnicity", + "organ", + ] self.obs_keys = [ "assay_sc", "assay_differentiation", "assay_type_differentiation", "bio_sample", "cell_line", - "cell_types_original", - "cellontology_class", - "cellontology_id", + "cell_type", "development_stage", "disease", "ethnicity", @@ -152,8 +165,8 @@ def __init__(self): "tech_sample", ] self.var_keys = [ - "gene_id_ensembl", - "gene_id_symbols", + "feature_id", + "feature_symbol", ] self.uns_keys = [ "annotated", @@ -181,23 +194,23 @@ class AdataIdsCellxgene(AdataIds): accepted_file_names: List[str] def __init__(self): + self.onto_id_suffix = "_ontology_term_id" + self.onto_original_suffix = "_original" + self.assay_sc = "assay" self.author = None - self.cell_types_original = "free_annotation" # TODO "free_annotation" not always given - # TODO: -> This will break streamlining though if self.cell_types_original is the same value as self.cellontology_class!! - self.cellontology_class = "cell_type" - self.cellontology_id = "cell_type_ontology_term_id" + self.cell_type = "cell_type" self.default_embedding = "default_embedding" self.doi_journal = "publication_doi" self.doi_preprint = "preprint_doi" self.disease = "disease" - self.gene_id_symbols = "index" - self.gene_id_ensembl = None # TODO not yet streamlined - self.gene_id_index = self.gene_id_symbols + self.feature_id = "ensembl" + self.feature_symbol = None self.id = "id" self.ncells = "ncells" self.organ = "tissue" self.organism = "organism" + self.primary_data = "is_primary_data" self.title = "title" self.year = "year" @@ -211,19 +224,24 @@ def __init__(self): # selected element entries used for parsing: self.author_names = "names" - self.unknown_celltype_identifier = None - self.not_a_cell_celltype_identifier = self.unknown_celltype_identifier - self.unknown_metadata_identifier = "unknown" self.invalid_metadata_identifier = "na" - self.unknown_metadata_ontology_id_identifier = "" + self.not_a_cell_celltype_identifier = "CL:0000003" + self.unknown_metadata_identifier = "unknown" self.batch_keys = [] + self.ontology_constrained = [ + "assay_sc", + "cell_type", + "development_stage", + "disease", + "ethnicity", + "organ", + ] + self.obs_keys = [ "assay_sc", - "cell_types_original", - "cellontology_class", - "cellontology_id", + "cell_type", "development_stage", "disease", "ethnicity", @@ -233,8 +251,8 @@ def __init__(self): "tech_sample", ] self.var_keys = [ - "gene_id_ensembl", - "gene_id_symbols", + "feature_id", + "feature_symbol", ] self.uns_keys = [ "doi_journal", @@ -253,3 +271,43 @@ def __init__(self): "organism", "sex", ] + + @property + def feature_index(self): + # Note this attribute is only filled in descendant classes. 
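        # Editor's note (not part of the patch): feature_symbol is None in this base
        # class and is overridden by the concrete schema classes defined below, e.g.
        # AdataIdsCellxgeneHuman_v1_1_0 ("hgnc_gene_symbol") and
        # AdataIdsCellxgene_v2_0_0 ("feature_name"), so feature_index resolves to the
        # schema-specific symbol column.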
+ return self.feature_symbol + + +class AdataIdsCellxgeneHuman_v1_1_0(AdataIdsCellxgene): + + def __init__(self): + super(AdataIdsCellxgeneHuman_v1_1_0, self).__init__() + self.feature_symbol = "hgnc_gene_symbol" + + +class AdataIdsCellxgeneMouse_v1_1_0(AdataIdsCellxgene): + + def __init__(self): + super(AdataIdsCellxgeneMouse_v1_1_0, self).__init__() + self.gene_id_symbols = "mgi_gene_symbol" + + +class AdataIdsCellxgeneGeneral(AdataIdsCellxgene): + + def __init__(self): + super(AdataIdsCellxgeneGeneral, self).__init__() + self.gene_id_symbols = "gene_symbol" + + +class AdataIdsCellxgene_v2_0_0(AdataIdsCellxgene): + + """ + https://github.com/chanzuckerberg/single-cell-curation/blob/main/schema/2.0.0/corpora_schema.md + """ + + def __init__(self): + super(AdataIdsCellxgene_v2_0_0, self).__init__() + self.feature_symbol = "feature_name" + self.feature_id = "feature_id" + self.feature_biotype = "feature_biotype" + # feature_referencec diff --git a/sfaira/consts/ontologies.py b/sfaira/consts/ontologies.py index 0dc0da996..dbf8c44b5 100644 --- a/sfaira/consts/ontologies.py +++ b/sfaira/consts/ontologies.py @@ -1,7 +1,7 @@ from typing import Dict, Union from sfaira.versions.metadata import OntologyList, OntologyCl -from sfaira.versions.metadata import OntologyCellosaurus, OntologyHsapdv, OntologyMondo, \ +from sfaira.versions.metadata import OntologyCellosaurus, OntologyHancestro, OntologyHsapdv, OntologyMondo, \ OntologyMmusdv, OntologySinglecellLibraryConstruction, OntologyUberon DEFAULT_CL = "v2021-02-01" @@ -17,8 +17,9 @@ class OntologyContainerSfaira: _assay_sc: Union[None, OntologySinglecellLibraryConstruction] _cell_line: Union[None, OntologyCellosaurus] - _cellontology_class: Union[None, OntologyCl] + _cell_type: Union[None, OntologyCl] _development_stage: Union[None, Dict[str, Union[OntologyHsapdv, OntologyMmusdv]]] + _ethnicity: Union[None, Dict[str, Union[OntologyHancestro, None]]] _organ: Union[None, OntologyUberon] def __init__(self): @@ -29,8 +30,7 @@ def __init__(self): self.assay_type_differentiation = OntologyList(terms=["guided", "unguided"]) self.bio_sample = None self._cell_line = None - self._cellontology_class = None - self.cell_types_original = None + self._cell_type = None self.collection_id = None self.default_embedding = None self._development_stage = None @@ -39,10 +39,7 @@ def __init__(self): self.doi_main = None self.doi_journal = None self.doi_preprint = None - self.ethnicity = { - "human": None, - "mouse": None, - } + self._ethnicity = None self.id = None self.individual = None self.normalization = None @@ -63,9 +60,19 @@ def reload_ontology(self, attr): elif attr == "cell_line": self._cell_line = OntologyCellosaurus(**kwargs) elif attr == "cellontology_class": - self._cellontology_class = OntologyCl(branch=DEFAULT_CL, **kwargs) + self._cell_type = OntologyCl(branch=DEFAULT_CL, **kwargs) + elif attr == "development_stage": + self._development_stage = { + "human": OntologyHsapdv(), + "mouse": OntologyMmusdv(), + } elif attr == "disease": self._disease = OntologyMondo(**kwargs) + elif attr == "ethnicity": + self._ethnicity = { + "human": OntologyHancestro(), + "mouse": None, + } elif attr == "organ": self._organ = OntologyUberon(**kwargs) return self._assay_sc @@ -83,14 +90,14 @@ def cell_line(self): return self._cell_line @property - def cellontology_class(self): - if self._cellontology_class is None: - self._cellontology_class = OntologyCl(branch=DEFAULT_CL) - return self._cellontology_class + def cell_type(self): + if self._cell_type is None: + self._cell_type = 
OntologyCl(branch=DEFAULT_CL) + return self._cell_type - @cellontology_class.setter - def cellontology_class(self, x: str): - self._cellontology_class = OntologyCl(branch=x) + @cell_type.setter + def cell_type(self, x: str): + self._cell_type = OntologyCl(branch=x) @property def development_stage(self): @@ -107,6 +114,15 @@ def disease(self): self._disease = OntologyMondo() return self._disease + @property + def ethnicity(self): + if self._ethnicity is None: + self._ethnicity = { + "human": OntologyHancestro(), + "mouse": None, + } + return self._ethnicity + @property def organ(self): if self._organ is None: diff --git a/sfaira/data/dataloaders/base/dataset.py b/sfaira/data/dataloaders/base/dataset.py index cd70e2c2f..5aa0b7adc 100644 --- a/sfaira/data/dataloaders/base/dataset.py +++ b/sfaira/data/dataloaders/base/dataset.py @@ -20,9 +20,11 @@ from sfaira.versions.genomes import GenomeContainer from sfaira.versions.metadata import Ontology, OntologyHierarchical, CelltypeUniverse -from sfaira.consts import AdataIds, AdataIdsCellxgene, AdataIdsSfaira, META_DATA_FIELDS, OCS +from sfaira.consts import AdataIds, AdataIdsCellxgeneGeneral, AdataIdsCellxgeneHuman_v1_1_0, AdataIdsCellxgeneMouse_v1_1_0, \ + AdataIdsSfaira, META_DATA_FIELDS, OCS +from sfaira.data.dataloaders.export_adaptors import cellxgene_export_adaptor from sfaira.data.store.io_dao import write_dao -from sfaira.data.dataloaders.base.utils import is_child, clean_string, get_directory_formatted_doi +from sfaira.data.dataloaders.base.utils import is_child, get_directory_formatted_doi from sfaira.data.utils import collapse_matrix, read_yaml from sfaira.consts.utils import clean_id_str @@ -53,6 +55,7 @@ class DatasetBase(abc.ABC): _author: Union[None, str] _bio_sample: Union[None, str] _cell_line: Union[None, str] + _cell_type: Union[None, str] _default_embedding: Union[None, str] _development_stage: Union[None, str] _disease: Union[None, str] @@ -76,27 +79,25 @@ class DatasetBase(abc.ABC): _bio_sample: Union[None, str] _year: Union[None, int] - _assay_sc_obs_key: Union[None, str] - _assay_differentiation_obs_key: Union[None, str] - _assay_type_differentiation_obs_key: Union[None, str] - _assay_cell_line_obs_key: Union[None, str] - _cellontology_class_obs_key: Union[None, str] - _cellontology_id_obs_key: Union[None, str] - _cell_types_original_obs_key: Union[None, str] - _development_stage_obs_key: Union[None, str] - _disease_obs_key: Union[None, str] - _ethnicity_obs_key: Union[None, str] - _individual: Union[None, str] - _organ_obs_key: Union[None, str] - _organism_obs_key: Union[None, str] - _bio_sample_obs_key: Union[None, str] - _sample_source_obs_key: Union[None, str] - _sex_obs_key: Union[None, str] - _state_exact_obs_key: Union[None, str] - _tech_sample_obs_key: Union[None, str] - - _gene_id_symbols_var_key: Union[None, str] - _gene_id_ensembl_var_key: Union[None, str] + assay_sc_obs_key: Union[None, str] + assay_differentiation_obs_key: Union[None, str] + assay_type_differentiation_obs_key: Union[None, str] + assay_cell_line_obs_key: Union[None, str] + bio_sample_obs_key: Union[None, str] + cell_type_obs_key: Union[None, str] + development_stage_obs_key: Union[None, str] + disease_obs_key: Union[None, str] + ethnicity_obs_key: Union[None, str] + individual_obs_key: Union[None, str] + organ_obs_key: Union[None, str] + organism_obs_key: Union[None, str] + sample_source_obs_key: Union[None, str] + sex_obs_key: Union[None, str] + state_exact_obs_key: Union[None, str] + tech_sample_obs_key: Union[None, str] + + 
gene_id_symbols_var_key: Union[None, str] + gene_id_ensembl_var_key: Union[None, str] _celltype_universe: Union[None, CelltypeUniverse] _ontology_class_map: Union[None, dict] @@ -161,6 +162,7 @@ def __init__( self._assay_type_differentiation = None self._bio_sample = None self._cell_line = None + self._cell_type = None self._default_embedding = None self._development_stage = None self._disease = None @@ -184,31 +186,27 @@ def __init__( self._title = None self._year = None - self._assay_sc_obs_key = None - self._assay_differentiation_obs_key = None - self._assay_type_differentiation_obs_key = None - self._bio_sample_obs_key = None - self._cell_line_obs_key = None - self._cellontology_class_obs_key = None - self._cellontology_id_obs_key = None - self._cell_types_original_obs_key = None - self._development_stage_obs_key = None - self._disease_obs_key = None - self._ethnicity_obs_key = None - - self._individual_obs_key = None - self._organ_obs_key = None - self._organism_obs_key = None - self._sample_source_obs_key = None - self._sex_obs_key = None - self._state_exact_obs_key = None - self._tech_sample_obs_key = None - - self._gene_id_symbols_var_key = None - self._gene_id_ensembl_var_key = None + self.assay_sc_obs_key = None + self.assay_differentiation_obs_key = None + self.assay_type_differentiation_obs_key = None + self.bio_sample_obs_key = None + self.cell_line_obs_key = None + self.cell_type_obs_key = None + self.development_stage_obs_key = None + self.disease_obs_key = None + self.ethnicity_obs_key = None + self.individual_obs_key = None + self.organ_obs_key = None + self.organism_obs_key = None + self.sample_source_obs_key = None + self.sex_obs_key = None + self.state_exact_obs_key = None + self.tech_sample_obs_key = None + + self.gene_id_symbols_var_key = None + self.gene_id_ensembl_var_key = None self.class_maps = {"0": {}} - self._unknown_celltype_identifiers = self._adata_ids.unknown_celltype_identifier self._celltype_universe = None self._ontology_class_map = None @@ -234,12 +232,14 @@ def __init__( if v is not None and k not in ["organism", "sample_fns", "dataset_index"]: if isinstance(v, dict): # v is a dictionary over file-wise meta-data items assert self.sample_fn in v.keys(), f"did not find key {self.sample_fn} in yamls keys for {k}" - setattr(self, k, v[self.sample_fn]) - else: # v is a meta-data item - try: - setattr(self, k, v) - except AttributeError as e: - raise ValueError(f"An error occured when setting {k} as {v}: {e}") + v = v[self.sample_fn] + # Catches spelling errors in meta data definition (yaml keys). + if not hasattr(self, k) and not hasattr(self, "_" + k): + raise ValueError(f"Tried setting unavailable property {k}.") + try: + setattr(self, k, v) + except AttributeError as e: + raise ValueError(f"An error occured when setting {k} as {v}: {e}") # ID can be set now already because YAML was used as input instead of child class constructor. 
self.set_dataset_id(idx=yaml_vals["meta"]["dataset_index"]) @@ -448,31 +448,19 @@ def _add_missing_featurenames( self, match_to_reference: Union[str, bool, None], ): - # If schema does not include symbols or ensebl ids, add them to the schema so we can do the conversion - if hasattr(self._adata_ids, "gene_id_symbols"): - gene_id_symbols = self._adata_ids.gene_id_symbols - else: - gene_id_symbols = "gene_symbol" # add some default name if not in schema - self._adata_ids.gene_id_symbols = gene_id_symbols - if hasattr(self._adata_ids, "gene_id_ensembl"): - gene_id_ensembl = self._adata_ids.gene_id_ensembl - else: - gene_id_ensembl = "ensembl" # add some default name if not in schema - self._adata_ids.gene_id_ensembl = gene_id_ensembl - - if not self.gene_id_symbols_var_key and not self.gene_id_ensembl_var_key: + if self.gene_id_symbols_var_key is None and self.gene_id_ensembl_var_key is None: raise ValueError("Either gene_id_symbols_var_key or gene_id_ensembl_var_key needs to be provided in the" " dataloader") - elif not self.gene_id_symbols_var_key and self.gene_id_ensembl_var_key: + elif self.gene_id_symbols_var_key is None and self.gene_id_ensembl_var_key: # Convert ensembl ids to gene symbols id_dict = self.genome_container.id_to_symbols_dict ensids = self.adata.var.index if self.gene_id_ensembl_var_key == "index" else self.adata.var[self.gene_id_ensembl_var_key] - self.adata.var[gene_id_symbols] = [ + self.adata.var[self._adata_ids.feature_symbol] = [ id_dict[n.split(".")[0]] if n.split(".")[0] in id_dict.keys() else 'n/a' for n in ensids ] - self.gene_id_symbols_var_key = gene_id_symbols - elif self.gene_id_symbols_var_key and not self.gene_id_ensembl_var_key: + self.gene_id_symbols_var_key = self._adata_ids.feature_symbol + elif self.gene_id_symbols_var_key and self.gene_id_ensembl_var_key is None: # Convert gene symbols to ensembl ids id_dict = self.genome_container.symbol_to_id_dict id_strip_dict = self.genome_container.strippednames_to_id_dict @@ -489,8 +477,8 @@ def _add_missing_featurenames( ensids.append(id_strip_dict[n.split(".")[0]]) else: ensids.append('n/a') - self.adata.var[gene_id_ensembl] = ensids - self.gene_id_ensembl_var_key = gene_id_ensembl + self.adata.var[self._adata_ids.feature_id] = ensids + self.gene_id_ensembl_var_key = self._adata_ids.feature_id def _collapse_ensembl_gene_id_versions(self): """ @@ -649,28 +637,56 @@ def streamline_metadata( clean_var: bool = True, clean_uns: bool = True, clean_obs_names: bool = True, + keep_orginal_obs: bool = False, + keep_symbol_obs: bool = True, + keep_id_obs: bool = True, ): """ Streamline the adata instance to a defined output schema. Output format are saved in ADATA_FIELDS* classes. + Note on ontology-controlled meta data: + These are defined for a given format in `ADATA_FIELDS*.ontology_constrained`. + They may appear in three different formats: + - original (free text) annotation + - ontology symbol + - ontology ID + During streamlining, these ontology-controlled meta data are projected to all of these three different formats. + The initially annotated column may be any of these and is defined as "{attr}_obs_col". + The resulting three column per meta data item are named: + - ontology symbol: "{ADATA_FIELDS*.attr}" + - ontology ID: {ADATA_FIELDS*.attr}_{ADATA_FIELDS*.onto_id_suffix}" + - original (free text) annotation: "{ADATA_FIELDS*.attr}_{ADATA_FIELDS*.onto_original_suffix}" + :param schema: Export format. - "sfaira" - "cellxgene" :param clean_obs: Whether to delete non-streamlined fields in .obs, .obsm and .obsp. 
:param clean_var: Whether to delete non-streamlined fields in .var, .varm and .varp. :param clean_uns: Whether to delete non-streamlined fields in .uns. - :param clean_obs_names: Whether to replace obs_names with a string comprised of dataset id and an increasing integer. + :param clean_obs_names: Whether to replace obs_names with a string comprised of dataset id and an increasing + integer. + :param keep_orginal_obs: For ontology-constrained .obs columns, whether to keep a column with original + annotation. + :param keep_symbol_obs: For ontology-constrained .obs columns, whether to keep a column with ontology symbol + annotation. + :param keep_id_obs: For ontology-constrained .obs columns, whether to keep a column with ontology ID annotation. :return: """ + schema_version = schema.split(":")[-1] if ":" in schema else None self.__assert_loaded() # Set schema as provided by the user - if schema == "sfaira": + if schema.startswith("sfaira"): adata_target_ids = AdataIdsSfaira() - elif schema == "cellxgene": - adata_target_ids = AdataIdsCellxgene() + elif schema.startswith("cellxgene"): + if self.organism == "human": + adata_target_ids = AdataIdsCellxgeneHuman_v1_1_0() + elif self.organism == "human": + adata_target_ids = AdataIdsCellxgeneHuman_v1_1_0() + else: + adata_target_ids = AdataIdsCellxgeneGeneral() else: raise ValueError(f"did not recognize schema {schema}") @@ -682,9 +698,9 @@ def streamline_metadata( # Creating new var annotation var_new = pd.DataFrame() for k in adata_target_ids.var_keys: - if k == "gene_id_ensembl": + if k == "feature_id": if not self.gene_id_ensembl_var_key: - raise ValueError("gene_id_ensembl_var_key not set in dataloader despite being required by the " + raise ValueError("feature_id not set in dataloader despite being required by the " "selected meta data schema. please run streamline_features() first to create the " "missing annotation") elif self.gene_id_ensembl_var_key == "index": @@ -693,7 +709,7 @@ def streamline_metadata( var_new[getattr(adata_target_ids, k)] = self.adata.var[self.gene_id_ensembl_var_key].tolist() del self.adata.var[self.gene_id_ensembl_var_key] self.gene_id_ensembl_var_key = getattr(adata_target_ids, k) - elif k == "gene_id_symbols": + elif k == "feature_symbol": if not self.gene_id_symbols_var_key: raise ValueError("gene_id_symbols_var_key not set in dataloader despite being required by the " "selected meta data schema. 
please run streamline_features() first to create the " @@ -710,7 +726,25 @@ def streamline_metadata( val = val[0] var_new[getattr(adata_target_ids, k)] = val # set var index - var_new.index = var_new[adata_target_ids.gene_id_index].tolist() + var_new.index = var_new[adata_target_ids.feature_index].tolist() + if clean_var: + if self.adata.varm is not None: + del self.adata.varm + if self.adata.varp is not None: + del self.adata.varp + self.adata.var = var_new + if "feature_id" not in adata_target_ids.var_keys: + self.gene_id_ensembl_var_key = None + if "feature_symbol" not in adata_target_ids.var_keys: + self.gene_id_symbols_var_key = None + else: + index_old = self.adata.var.index.copy() + # Add old columns in if they are not duplicated: + self.adata.var = pd.concat([ + var_new, + pd.DataFrame(dict([(k, v) for k, v in self.adata.var.items() if k not in var_new.columns])) + ], axis=1) + self.adata.var.index = index_old # Prepare new .uns dict: uns_new = {} @@ -719,20 +753,46 @@ def streamline_metadata( val = getattr(self, k) elif hasattr(self, f"{k}_obs_key") and getattr(self, f"{k}_obs_key") is not None: val = np.sort(np.unique(self.adata.obs[getattr(self, f"{k}_obs_key")].values)).tolist() + elif getattr(self._adata_ids, k) in self.adata.obs.columns: + val = np.sort(np.unique(self.adata.obs[getattr(self._adata_ids, k)].values)).tolist() else: val = None while hasattr(val, '__len__') and not isinstance(val, str) and len(val) == 1: # Unpack nested lists/tuples. val = val[0] uns_new[getattr(adata_target_ids, k)] = val + if clean_uns: + self.adata.uns = uns_new + else: + self.adata.uns.update(uns_new) # Prepare new .obs dataframe - per_cell_labels = ["cell_types_original", "cellontology_class", "cellontology_id"] + # Queried meta data may be: + # 1) in .obs + # a) for an ontology-constrained meta data item + # I) as free annotation with a term map to an ontology + # II) as column with ontology symbols + # III) as column with ontology IDs + # b) for a non-ontology-constrained meta data item: + # I) as free annotation + # 2) in .uns + # b) as elements that are ontology symbols + # c) as elements that are ontology IDs + # .obs annotation takes priority over .uns annotation if both are present. + # The output columns are: + # - for an ontology-constrained meta data item "attr": + # * symbols: "attr" + # * IDs: "attr" + self._adata_ids.onto_id_suffix + # * original labels: "attr" + self._adata_ids.onto_original_suffix + # - for a non-ontology-constrained meta data item "attr": + # * original labels: "attr" + self._adata_ids.onto_original_suffix obs_new = pd.DataFrame(index=self.adata.obs.index) - # Handle non-cell type labels: - for k in [x for x in adata_target_ids.obs_keys if x not in per_cell_labels]: - # Handle batch-annotation columns which can be provided as a combination of columns separated by an asterisk + for k in [x for x in adata_target_ids.obs_keys]: if k in experiment_batch_labels and getattr(self, f"{k}_obs_key") is not None and \ "*" in getattr(self, f"{k}_obs_key"): + # Handle batch-annotation columns which can be provided as a combination of columns separated by an + # asterisk. + # The queried meta data are always: + # 1b-I) a combination of existing columns in .obs old_cols = getattr(self, f"{k}_obs_key") batch_cols = [] for batch_col in old_cols.split("*"): @@ -749,13 +809,15 @@ def streamline_metadata( for xx in zip(*[self.adata.obs[batch_col].values.tolist() for batch_col in batch_cols]) ] else: + # Locate annotation. 
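                # Editor's illustration (not part of the patch): for an ontology-constrained
                # item such as organ annotated as free text via organ_obs_key, this step
                # ultimately produces three columns named with the suffixes defined on the
                # ID object, e.g. with the sfaira defaults:
                #     obs["organ"]                  -> "lung"             (ontology symbol)
                #     obs["organ_ontology_term_id"] -> "UBERON:0002048"   (ontology ID)
                #     obs["organ_original"]         -> free-text label from the data loader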
if hasattr(self, f"{k}_obs_key") and getattr(self, f"{k}_obs_key") is not None and \ getattr(self, f"{k}_obs_key") in self.adata.obs.columns: # Last and-clause to check if this column is included in data sets. This may be violated if data # is obtained from a database which is not fully streamlined. - old_col = getattr(self, f"{k}_obs_key") - val = self.adata.obs[old_col].values.tolist() + # Look for 1a-* and 1b-I + val = self.adata.obs[getattr(self, f"{k}_obs_key")].values.tolist() else: + # Look for 2a, 2b val = getattr(self, k) if val is None: val = self._adata_ids.unknown_metadata_identifier @@ -763,56 +825,37 @@ def streamline_metadata( while hasattr(val, '__len__') and not isinstance(val, str) and len(val) == 1: val = val[0] val = [val] * self.adata.n_obs - new_col = getattr(adata_target_ids, k) + # Identify annotation: disambiguate 1a-I, 1a-II, 1a-III, 1b-I. + if k in self._adata_ids.ontology_constrained: + # 1a-*. + if isinstance(self.get_ontology(k=k), OntologyHierarchical) and np.all([ + self.get_ontology(k=k).is_a_node_name(x) or x == self._adata_ids.unknown_metadata_identifier + for x in np.unique(val) + ]): # 1a-II) + new_col = getattr(adata_target_ids, k) + validation_ontology = self.get_ontology(k=k) + elif isinstance(self.get_ontology(k=k), OntologyHierarchical) and np.all([ + self.get_ontology(k=k).is_a_node_id(x) or x == self._adata_ids.unknown_metadata_identifier + for x in np.unique(val) + ]): # 1a-III) + new_col = getattr(adata_target_ids, k) + self._adata_ids.onto_id_suffix + validation_ontology = None + else: # 1a-I) + new_col = getattr(adata_target_ids, k) + self._adata_ids.onto_original_suffix + validation_ontology = None + else: + # 1b-I. + new_col = getattr(adata_target_ids, k) + validation_ontology = self.get_ontology(k=k) # Check values for validity: - ontology = getattr(self.ontology_container_sfaira, k) \ - if hasattr(self.ontology_container_sfaira, k) else None - if k == "development_stage": - ontology = ontology[self.organism] - if k == "ethnicity": - ontology = ontology[self.organism] - self._value_protection(attr=new_col, allowed=ontology, attempted=[ + self._value_protection(attr=new_col, allowed=validation_ontology, attempted=[ x for x in np.unique(val) if x not in [ self._adata_ids.unknown_metadata_identifier, - self._adata_ids.unknown_metadata_ontology_id_identifier, ] ]) obs_new[new_col] = val - setattr(self, f"{k}_obs_key", new_col) - # Set cell types: - # Build auxilliary table with cell type information: - if self.cell_types_original_obs_key is not None: - obs_cl = self.project_celltypes_to_ontology(copy=True, adata_fields=self._adata_ids) - else: - obs_cl = pd.DataFrame({ - self._adata_ids.cellontology_class: [self._adata_ids.unknown_metadata_identifier] * self.adata.n_obs, - self._adata_ids.cellontology_id: [self._adata_ids.unknown_metadata_identifier] * self.adata.n_obs, - self._adata_ids.cell_types_original: [self._adata_ids.unknown_metadata_identifier] * self.adata.n_obs, - }, index=self.adata.obs.index) - for k in [x for x in per_cell_labels if x in adata_target_ids.obs_keys]: - obs_new[getattr(adata_target_ids, k)] = obs_cl[getattr(self._adata_ids, k)] - del obs_cl - - # Add new annotation to adata and delete old fields if requested - if clean_var: - if self.adata.varm is not None: - del self.adata.varm - if self.adata.varp is not None: - del self.adata.varp - self.adata.var = var_new - if "gene_id_ensembl" not in adata_target_ids.var_keys: - self.gene_id_ensembl_var_key = None - if "gene_id_symbols" not in adata_target_ids.var_keys: - 
self.gene_id_symbols_var_key = None - else: - index_old = self.adata.var.index.copy() - # Add old columns in if they are not duplicated: - self.adata.var = pd.concat([ - var_new, - pd.DataFrame(dict([(k, v) for k, v in self.adata.var.items() if k not in var_new.columns])) - ], axis=1) - self.adata.var.index = index_old + # For ontology-constrained meta data, the remaining columns are added after .obs cleaning below. if clean_obs: if self.adata.obsm is not None: del self.adata.obsm @@ -829,12 +872,21 @@ def streamline_metadata( if k not in adata_target_ids.controlled_meta_keys])) ], axis=1) self.adata.obs.index = index_old + for k in [x for x in adata_target_ids.obs_keys if x in adata_target_ids.ontology_constrained]: + # Add remaining output columns for ontology-constrained meta data. + self.__impute_ontology_cols_obs(attr=k, adata_ids=adata_target_ids) + # Delete attribute-specific columns that are not desired. + col_name = getattr(self._adata_ids, k) + self._adata_ids.onto_id_suffix + if not keep_id_obs and col_name in self.adata.obs.columns: + del self.adata.obs[col_name] + col_name = getattr(self._adata_ids, k) + self._adata_ids.onto_original_suffix + if not keep_orginal_obs and col_name in self.adata.obs.columns: + del self.adata.obs[col_name] + col_name = getattr(self._adata_ids, k) + if not keep_symbol_obs and col_name in self.adata.obs.columns: + del self.adata.obs[col_name] if clean_obs_names: self.adata.obs.index = [f"{self.id}_{i}" for i in range(1, self.adata.n_obs + 1)] - if clean_uns: - self.adata.uns = uns_new - else: - self.adata.uns = {**self.adata.uns, **uns_new} # Make sure that correct unknown_metadata_identifier is used in .uns, .obs and .var metadata unknown_old = self._adata_ids.unknown_metadata_identifier @@ -847,90 +899,11 @@ def streamline_metadata( if self.adata.uns[k] is None or self.adata.uns[k] == unknown_old: self.adata.uns[k] = unknown_new - # Add additional hard-coded description changes for cellxgene schema: - if schema == "cellxgene": - self.adata.uns["layer_descriptions"] = {"X": "raw"} - self.adata.uns["version"] = { - "corpora_encoding_version": "0.1.0", - "corpora_schema_version": "1.1.0", - } - self.adata.uns["contributors"] = { - "name": "sfaira", - "email": "https://github.com/theislab/sfaira/issues", - "institution": "sfaira", - } - # TODO port this into organism ontology handling. - if self.organism == "mouse": - self.adata.uns["organism"] = "Mus musculus" - self.adata.uns["organism_ontology_term_id"] = "NCBITaxon:10090" - elif self.organism == "human": - self.adata.uns["organism"] = "Homo sapiens" - self.adata.uns["organism_ontology_term_id"] = "NCBITaxon:9606" - else: - raise ValueError(f"organism {self.organism} currently not supported by cellxgene schema") - # Add ontology IDs where necessary (note that human readable terms are also kept): - ontology_cols = ["organ", "assay_sc", "disease", "ethnicity", "development_stage"] - non_ontology_cols = ["sex"] - for k in ontology_cols: - # TODO enable ethinicity once the distinction between ontology for human and None for mouse works. 
- if getattr(adata_target_ids, k) in self.adata.obs.columns and k != "ethnicity": - ontology = getattr(self.ontology_container_sfaira, k) - # Disambiguate organism-dependent ontologies: - if isinstance(ontology, dict): - ontology = ontology[self.organism] - self.__project_name_to_id_obs( - ontology=ontology, - key_in=getattr(adata_target_ids, k), - key_out=getattr(adata_target_ids, k) + "_ontology_term_id", - map_exceptions=[adata_target_ids.unknown_metadata_identifier], - map_exceptions_value=adata_target_ids.unknown_metadata_ontology_id_identifier, - ) - else: - self.adata.obs[getattr(adata_target_ids, k)] = adata_target_ids.unknown_metadata_identifier - self.adata.obs[getattr(adata_target_ids, k) + "_ontology_term_id"] = \ - adata_target_ids.unknown_metadata_ontology_id_identifier - # Correct unknown cell type entries: - self.adata.obs[getattr(adata_target_ids, "cellontology_class")] = [ - x if x not in [self._adata_ids.unknown_celltype_identifier, - self._adata_ids.not_a_cell_celltype_identifier] - else "native cell" - for x in self.adata.obs[getattr(adata_target_ids, "cellontology_class")]] - self.adata.obs[getattr(adata_target_ids, "cellontology_id")] = [ - x if x not in [self._adata_ids.unknown_celltype_identifier, - self._adata_ids.not_a_cell_celltype_identifier] - else "CL:0000003" - for x in self.adata.obs[getattr(adata_target_ids, "cellontology_id")]] - # Reorder data frame to put ontology columns first: - cellxgene_cols = [getattr(adata_target_ids, x) for x in ontology_cols] + \ - [getattr(adata_target_ids, x) for x in non_ontology_cols] + \ - [getattr(adata_target_ids, x) + "_ontology_term_id" for x in ontology_cols] - self.adata.obs = self.adata.obs[ - cellxgene_cols + [x for x in self.adata.obs.columns if x not in cellxgene_cols] - ] - # Adapt var columns naming. - if self.organism == "human": - gene_id_new = "hgnc_gene_symbol" - elif self.organism == "mouse": - gene_id_new = "mgi_gene_symbol" - else: - raise ValueError(f"organism {self.organism} currently not supported") - self.adata.var[gene_id_new] = self.adata.var[getattr(adata_target_ids, "gene_id_symbols")] - self.adata.var.index = self.adata.var[gene_id_new].tolist() - if gene_id_new != self.gene_id_symbols_var_key: - del self.adata.var[self.gene_id_symbols_var_key] - self.gene_id_symbols_var_key = gene_id_new - # Check if .X is counts: The conversion are based on the assumption that .X is csr. - assert isinstance(self.adata.X, scipy.sparse.csr_matrix), type(self.adata.X) - count_values = np.unique(np.asarray(self.adata.X.todense())) - if not np.all(count_values % 1. == 0.): - print(f"WARNING: not all count entries were counts, " - f"the maximum deviation from integer is " - f"{np.max([x % 1. if x % 1. < 0.5 else 1. - x % 1. for x in count_values])}. 
" - f"The count matrix is rounded.") - self.adata.X.data = np.rint(self.adata.X.data) - self._adata_ids = adata_target_ids # set new adata fields to class after conversion self.streamlined_meta = True + # Add additional hard-coded description changes for cellxgene schema: + if schema.startswith("cellxgene"): + self.adata = cellxgene_export_adaptor(adata=self.adata, adata_ids=self._adata_ids, version=schema_version) def write_distributed_store( self, @@ -1066,11 +1039,29 @@ def _set_genome(self, assembly: Union[str, None]): def doi_cleaned_id(self): return "_".join(self.id.split("_")[:-1]) + def get_ontology(self, k) -> OntologyHierarchical: + x = getattr(self.ontology_container_sfaira, k) if hasattr(self.ontology_container_sfaira, k) else None + if isinstance(x, dict): + assert isinstance(self.organism, str) + x = x[self.organism] + return x + @property def fn_ontology_class_map_tsv(self): """Standardised file name under which cell type conversion tables are saved.""" return self.doi_cleaned_id + ".tsv" + def _write_ontology_class_map(self, fn, tab: pd.DataFrame): + """ + Write class map to file. + + Helper to allow direct interaction with written table instead of using table from instance. + + :param fn: File name of csv to write class maps to. + :param tab: Class map table. + """ + tab.to_csv(fn, index=False, sep="\t") + def write_ontology_class_map( self, fn, @@ -1085,34 +1076,26 @@ def write_ontology_class_map( :return: """ if not self.annotated: - warnings.warn(f"attempted to write ontology classmaps for data set {self.id} without annotation") + warnings.warn(f"attempted to write ontology class maps for data set {self.id} without annotation") else: - labels_original = np.sort(np.unique(self.adata.obs[self.cell_types_original_obs_key].values)) + labels_original = np.sort(np.unique(self.adata.obs[self.cell_type_obs_key].values)) tab = self.celltypes_universe.prepare_celltype_map_tab( source=labels_original, match_only=False, anatomical_constraint=self.organ, include_synonyms=True, - omit_list=self._unknown_celltype_identifiers, + omit_list=[self._adata_ids.unknown_metadata_identifier], **kwargs ) if not os.path.exists(fn) or not protected_writing: - self._write_class_map(fn=fn, tab=tab) + self._write_ontology_class_map(fn=fn, tab=tab) - def _write_class_map(self, fn, tab): - """ - Write class map. - - :param fn: File name of csv to write class maps to. - :param tab: Table to write - :return: - """ - tab.to_csv(fn, index=False, sep="\t") - - def _read_class_map(self, fn) -> pd.DataFrame: + def _read_ontology_class_map(self, fn) -> pd.DataFrame: """ Read class map. + Helper to allow direct interaction with resulting table instead of loading into instance. + :param fn: File name of csv to load class maps from. :return: """ @@ -1124,7 +1107,7 @@ def _read_class_map(self, fn) -> pd.DataFrame: raise pandas.errors.ParserError(e) return tab - def load_ontology_class_map(self, fn): + def read_ontology_class_map(self, fn): """ Load class maps of free text cell types to ontology classes. 
@@ -1132,130 +1115,191 @@ def load_ontology_class_map(self, fn): :return: """ if os.path.exists(fn): - self.cell_ontology_map = self._read_class_map(fn=fn) + self.cell_type_map = self._read_ontology_class_map(fn=fn) else: - if self.cell_types_original_obs_key is not None: - warnings.warn(f"file {fn} does not exist but cell_types_original_obs_key is given") + if self.cell_type_obs_key is not None: + warnings.warn(f"file {fn} does not exist but cell_type_obs_key {self.cell_type_obs_key} is given") - def project_celltypes_to_ontology(self, adata_fields: Union[AdataIds, None] = None, copy=False, update_fields=True): + def project_free_to_ontology(self, attr: str, copy: bool = False): """ Project free text cell type names to ontology based on mapping table. ToDo: add ontology ID setting here. + ToDo: only for cell type right now, extend to other meta data in the future. - :param adata_fields: AdataIds instance that holds the column names to use for the annotation :param copy: If True, a dataframe with the celltype annotation is returned, otherwise self.adata.obs is updated inplace. - :param update_fields: If True, the celltype-related attributes of this Dataset instance are updated. Basically, - this should always be true, unless self.adata.obs is not updated by (or with the output of) this function. - This includes the following fields: self.cellontology_class_obs_key, self.cell_types_original_obs_key, - self.cellontology_id_obs_key :return: """ - assert copy or update_fields, "when copy is set to False, update_fields cannot be False" - - adata_fields = adata_fields if adata_fields is not None else self._adata_ids + ontology_map = attr + "_map" + if hasattr(self, ontology_map): + ontology_map = getattr(self, ontology_map) + else: + ontology_map = None + print(f"WARNING: did not find ontology map for {attr} which was only defined by free annotation") + adata_fields = self._adata_ids results = {} - labels_original = self.adata.obs[self.cell_types_original_obs_key].values - if self.cell_ontology_map is not None: # only if this was defined + col_original = attr + adata_fields.onto_original_suffix + labels_original = self.adata.obs[col_original].values + if ontology_map is not None: # only if this was defined labels_mapped = [ - self.cell_ontology_map[x] if x in self.cell_ontology_map.keys() + ontology_map[x] if x in ontology_map.keys() else x for x in labels_original ] # Convert unknown celltype placeholders (needs to be hardcoded here as placeholders are also hardcoded in # conversion tsv files placeholder_conversion = { - "UNKNOWN": adata_fields.unknown_celltype_identifier, + "UNKNOWN": adata_fields.unknown_metadata_identifier, "NOT_A_CELL": adata_fields.not_a_cell_celltype_identifier, } labels_mapped = [ placeholder_conversion[x] if x in placeholder_conversion.keys() else x for x in labels_mapped ] + map_exceptions = [adata_fields.unknown_metadata_identifier] + if attr == "cell_type": + map_exceptions.append(adata_fields.not_a_cell_celltype_identifier) # Validate mapped IDs based on ontology: # This aborts with a readable error if there was a target in the mapping file that doesnt match the ontology # This protection blocks progression in the unit test if not deactivated. 
self._value_protection( - attr="celltypes", - allowed=self.ontology_celltypes, - attempted=[ - x for x in list(set(labels_mapped)) - if x not in [ - adata_fields.unknown_celltype_identifier, - adata_fields.not_a_cell_celltype_identifier - ] - ] + attr=attr, + allowed=getattr(self.ontology_container_sfaira, attr), + attempted=[x for x in list(set(labels_mapped)) if x not in map_exceptions], ) # Add cell type IDs into object: # The IDs are not read from a source file but inferred based on the class name. # TODO this could be changed in the future, this allows this function to be used both on cell type name # mapping files with and without the ID in the third column. # This mapping blocks progression in the unit test if not deactivated. - ontology = getattr(self.ontology_container_sfaira, "cellontology_class") - ids_mapped = self.__project_name_to_id_obs( - ontology=ontology, - key_in=labels_mapped, - key_out=None, - map_exceptions=[ - adata_fields.unknown_celltype_identifier, - adata_fields.not_a_cell_celltype_identifier - ], - ) - results[adata_fields.cellontology_class] = labels_mapped - results[adata_fields.cellontology_id] = ids_mapped - if update_fields: - self.cellontology_id_obs_key = adata_fields.cellontology_id + results[getattr(adata_fields, attr)] = labels_mapped + self.__project_ontology_ids_obs(attr=attr, map_exceptions=map_exceptions, from_id=False, + adata_ids=adata_fields) else: - results[adata_fields.cellontology_class] = labels_original - results[adata_fields.cellontology_id] = [adata_fields.unknown_metadata_identifier] * self.adata.n_obs - results[adata_fields.cell_types_original] = labels_original - if update_fields: - self.cellontology_class_obs_key = adata_fields.cellontology_class - self.cell_types_original_obs_key = adata_fields.cell_types_original + results[getattr(adata_fields, attr)] = labels_original + results[getattr(adata_fields, attr) + adata_fields.onto_id_suffix] = \ + [adata_fields.unknown_metadata_identifier] * self.adata.n_obs + results[getattr(adata_fields, attr) + adata_fields.onto_original_suffix] = labels_original if copy: return pd.DataFrame(results, index=self.adata.obs.index) else: for k, v in results.items(): self.adata.obs[k] = v - def __project_name_to_id_obs( + def __impute_ontology_cols_obs( self, - ontology: OntologyHierarchical, - key_in: Union[str, list], - key_out: Union[str, None], - map_exceptions: list, + attr: str, + adata_ids: AdataIds, + ): + """ + Add missing ontology defined columns (symbol, ID, original) for a given ontology. + + 1) If original column is non-empty and symbol and ID are empty: + orginal column is projected to ontology and both symbol and ID are inferred. + Note that in this case, a label map is required. + 2) If ID column is non-empty or symbol is non-empty, an error is thrown. + a) If ID column is non-empty and symbol is empty, symbol is inferred. + b) If ID column is empty and symbol is non-empty, ID is inferred. + c) If ID column is non-empty and non-symbol is empty, symbol is inferred and over-written. + Note that this setting allows usage of data sets which were streamlined with a different ontology + version. + In all cases original is kept if it is set and is set to symbol otherwise. + 3) If original, ID and symbol columns are empty, no action is taken (meta data item was not set). 
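        Example (editorial illustration, not part of the patch): a loader that only
        provides the ID column, e.g. "cell_type" + onto_id_suffix (i.e.
        "cell_type_ontology_term_id") filled with "CL:0000084", falls under case 2a
        above; the symbol column is then inferred from the ontology ("T cell") and,
        since no original column was set, the original column is filled with the
        symbol as well.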
+ """ + ontology = self.get_ontology(k=attr) + col_symbol = getattr(adata_ids, attr) + col_id = getattr(adata_ids, attr) + self._adata_ids.onto_id_suffix + col_original = getattr(adata_ids, attr) + self._adata_ids.onto_original_suffix + if ontology is None: + # Fill with invalid ontology identifiers if no ontology was found. + self.adata.obs[col_id] = \ + [self._adata_ids.invalid_metadata_identifier for _ in range(self.adata.n_obs)] + self.adata.obs[col_original] = \ + [self._adata_ids.invalid_metadata_identifier for _ in range(self.adata.n_obs)] + self.adata.obs[col_symbol] = \ + [self._adata_ids.invalid_metadata_identifier for _ in range(self.adata.n_obs)] + else: + # Note that for symbol and ID, the columns may be filled but not streamlined according to the ontology, + # in that case the corresponding meta data is defined as absent. + # Check which level of meta data annotation is present. + # Symbols: + symbol_col_present = col_symbol in self.adata.obs.columns + symbol_col_streamlined = np.all([ + ontology.is_a_node_name(x) or x == self._adata_ids.unknown_metadata_identifier + for x in np.unique(self.adata.obs[col_symbol].values)]) if symbol_col_present else False + symbol_present = symbol_col_present and symbol_col_streamlined + # IDs: + id_col_present = col_id in self.adata.obs.columns + id_col_streamlined = np.all([ + ontology.is_a_node_id(x) or x == self._adata_ids.unknown_metadata_identifier + for x in np.unique(self.adata.obs[col_id].values)]) if id_col_present else False + id_present = id_col_present and id_col_streamlined + # Original annotation (free text): + original_present = col_original in self.adata.obs.columns + if original_present and not symbol_present and not id_present: # 1) + self.project_free_to_ontology(attr=attr, copy=False) + if symbol_present or id_present: # 2) + if symbol_present and not id_present: # 2a) + self.__project_ontology_ids_obs(attr=attr, from_id=False, adata_ids=adata_ids) + if not symbol_present and id_present: # 2b) + self.__project_ontology_ids_obs(attr=attr, from_id=True, adata_ids=adata_ids) + if symbol_present and id_present: # 2c) + self.__project_ontology_ids_obs(attr=attr, from_id=True, adata_ids=adata_ids) + if not original_present: + val = self.adata.obs[col_symbol] + self.adata.obs[col_original] = val + + def __project_ontology_ids_obs( + self, + attr: str, + adata_ids: AdataIds, + map_exceptions: Union[None, List[str]] = None, map_exceptions_value=None, + from_id: bool = False, ): """ Project ontology names to IDs for a given ontology in .obs entries. :param ontology: ontology to use when converting to IDs - :param key_in: name of obs_column containing names to convert or python list containing these values - :param key_out: name of obs_column to write the IDs or None. If None, a python list with the new values will be returned - :param map_exceptions: list of values that should not be mapped - :param map_exceptions_value: placeholder target value for values excluded from mapping + :param attr: name of obs_column containing names to convert or python list containing these values + :param map_exceptions: list of values that should not be mapped. + Defaults to unknown meta data identifier defined in ID object if None. + :param map_exceptions_value: placeholder target value for values excluded from mapping. + Defaults to unknown meta data identifier defined in ID object if None. + :param from_id: Whether to output ontology symbol or ID. 
:return: """ - assert ontology is not None, f"cannot project value for {key_in} because ontology is None" - assert isinstance(key_in, (str, list)), f"argument key_in needs to be of type str or list. Supplied" \ - f"type: {type(key_in)}" - input_values = self.adata.obs[key_in].values if isinstance(key_in, str) else key_in + ontology = self.get_ontology(k=attr) + assert ontology is not None, f"cannot project value for {attr} because ontology is None" + assert isinstance(attr, (str, list)), f"argument key_in needs to be of type str or list. Supplied" \ + f"type: {type(attr)}" + map_exceptions = map_exceptions if map_exceptions is not None else [adata_ids.unknown_metadata_identifier] + map_exceptions = [x.lower() for x in map_exceptions] + if map_exceptions_value is None: + # TODO this may be simplified in the future once all unknown meta data labels are the same. + if attr == "cell_type": + map_exceptions_value = adata_ids.unknown_metadata_identifier + else: + map_exceptions_value = adata_ids.unknown_metadata_identifier + col_name = getattr(adata_ids, attr) + if from_id: + col_name += adata_ids.onto_id_suffix + input_values = self.adata.obs[col_name].values map_vals = dict([ + (x, ontology.convert_to_name(x)) if from_id else (x, ontology.convert_to_id(x)) for x in np.unique([ xx for xx in input_values - if (xx not in map_exceptions and xx is not None) + if (xx.lower() not in map_exceptions and xx is not None) ]) ]) output_values = [ map_vals[x] if x in map_vals.keys() else map_exceptions_value for x in input_values ] - if isinstance(key_out, str): - self.adata.obs[key_out] = output_values - else: - return output_values + key_out = getattr(adata_ids, attr) if from_id else getattr(adata_ids, attr) + adata_ids.onto_id_suffix + self.adata.obs[key_out] = output_values @property def citation(self): @@ -1333,40 +1377,14 @@ def write_meta( meta = pandas.DataFrame(index=range(1)) # Expand table by variably cell-wise or data set-wise meta data: for x in self._adata_ids.controlled_meta_fields: - if x in ["cell_types_original", "cellontology_class", "cellontology_id"]: - continue - elif x in ["bio_sample", "individual", "tech_sample"] and \ - hasattr(self, f"{x}_obs_key") and \ - getattr(self, f"{x}_obs_key") is not None and \ - "*" in getattr(self, f"{x}_obs_key"): - batch_cols = [] - for batch_col in getattr(self, f"{x}_obs_key").split("*"): - if batch_col in self.adata.obs_keys(): - batch_cols.append(batch_col) - else: - # This should not occur in single data set loaders (see warning below) but can occur in - # streamlined data loaders if not all instances of the streamlined data sets have all columns - # in .obs set. - print(f"WARNING: attribute {x} of data set {self.id} was not found in column {batch_col}") - # Build a combination label out of all columns used to describe this group. 
- meta[getattr(self._adata_ids, x)] = (list(set([ - "_".join([str(xxx) for xxx in xx]) - for xx in zip(*[self.adata.obs[batch_col].values.tolist() for batch_col in batch_cols]) - ])),) - elif hasattr(self, f"{x}_obs_key") and getattr(self, f"{x}_obs_key") is not None: - meta[getattr(self._adata_ids, x)] = (self.adata.obs[getattr(self, f"{x}_obs_key")].unique(),) + if hasattr(self, f"{x}_obs_key") and getattr(self, f"{x}_obs_key") is not None: + col = getattr(self._adata_ids, x) + meta[col] = (self.adata.obs[col].unique(), ) + if x in self._adata_ids.ontology_constrained: + col = getattr(self._adata_ids, x + self._adata_ids.onto_id_suffix) + meta[col] = (self.adata.obs[col].unique(), ) else: meta[getattr(self._adata_ids, x)] = getattr(self, x) - # Add cell types into table if available: - if self.cell_types_original_obs_key is not None: - mappings = self.project_celltypes_to_ontology(copy=True, update_fields=False) - meta[self._adata_ids.cellontology_class] = (mappings[self._adata_ids.cellontology_class].unique(),) - meta[self._adata_ids.cellontology_id] = (mappings[self._adata_ids.cellontology_id].unique(),) - meta[self._adata_ids.cell_types_original] = (mappings[self._adata_ids.cell_types_original].unique(),) - else: - meta[self._adata_ids.cellontology_class] = " " - meta[self._adata_ids.cellontology_id] = " " - meta[self._adata_ids.cell_types_original] = " " meta.to_csv(fn_meta) def set_dataset_id( @@ -1404,7 +1422,7 @@ def additional_annotation_key(self, x: str): @property def annotated(self) -> Union[bool, None]: - if self.cellontology_id_obs_key is not None or self.cell_types_original_obs_key is not None: + if self.cell_type_obs_key is not None: return True else: if self.meta is None: @@ -1504,6 +1522,22 @@ def cell_line(self) -> Union[None, str]: def cell_line(self, x: str): self._cell_line = x + @property + def cell_type(self) -> Union[None, str]: + if self._cell_type is not None: + return self._cell_type + else: + if self.meta is None: + self.load_meta(fn=None) + if self.meta is not None and self._adata_ids.cell_type in self.meta.columns: + return self.meta[self._adata_ids.cell_type] + else: + return None + + @cell_type.setter + def cell_type(self, x: str): + self._cell_type = x + @property def data_dir(self): # Data is either directly in user supplied directory or in a sub directory if the overall directory is managed @@ -1808,150 +1842,6 @@ def primary_data(self, x: bool): attempted=x) self._primary_data = x - @property - def assay_sc_obs_key(self) -> str: - return self._assay_sc_obs_key - - @assay_sc_obs_key.setter - def assay_sc_obs_key(self, x: str): - self._assay_sc_obs_key = x - - @property - def assay_differentiation_obs_key(self) -> str: - return self._assay_differentiation_obs_key - - @assay_differentiation_obs_key.setter - def assay_differentiation_obs_key(self, x: str): - self._assay_differentiation_obs_key = x - - @property - def assay_type_differentiation_obs_key(self) -> str: - return self._assay_type_differentiation_obs_key - - @assay_type_differentiation_obs_key.setter - def assay_type_differentiation_obs_key(self, x: str): - self._assay_type_differentiation_obs_key = x - - @property - def bio_sample_obs_key(self) -> str: - return self._bio_sample_obs_key - - @bio_sample_obs_key.setter - def bio_sample_obs_key(self, x: str): - self._bio_sample_obs_key = x - - @property - def cell_line_obs_key(self) -> str: - return self._cell_line_obs_key - - @cell_line_obs_key.setter - def cell_line_obs_key(self, x: str): - self._cell_line_obs_key = x - - @property - def 
cellontology_class_obs_key(self) -> str: - return self._cellontology_class_obs_key - - @cellontology_class_obs_key.setter - def cellontology_class_obs_key(self, x: str): - self._cellontology_class_obs_key = x - - @property - def cellontology_id_obs_key(self) -> str: - return self._cellontology_id_obs_key - - @cellontology_id_obs_key.setter - def cellontology_id_obs_key(self, x: str): - self._cellontology_id_obs_key = x - - @property - def cell_types_original_obs_key(self) -> str: - return self._cell_types_original_obs_key - - @cell_types_original_obs_key.setter - def cell_types_original_obs_key(self, x: str): - self._cell_types_original_obs_key = x - - @property - def development_stage_obs_key(self) -> str: - return self._development_stage_obs_key - - @development_stage_obs_key.setter - def development_stage_obs_key(self, x: str): - self._development_stage_obs_key = x - - @property - def disease_obs_key(self) -> str: - return self._disease_obs_key - - @disease_obs_key.setter - def disease_obs_key(self, x: str): - self._disease_obs_key = x - - @property - def ethnicity_obs_key(self) -> str: - return self._ethnicity_obs_key - - @ethnicity_obs_key.setter - def ethnicity_obs_key(self, x: str): - self._ethnicity_obs_key = x - - @property - def individual_obs_key(self) -> str: - return self._individual_obs_key - - @individual_obs_key.setter - def individual_obs_key(self, x: str): - self._individual_obs_key = x - - @property - def organ_obs_key(self) -> str: - return self._organ_obs_key - - @organ_obs_key.setter - def organ_obs_key(self, x: str): - self._organ_obs_key = x - - @property - def organism_obs_key(self) -> str: - return self._organism_obs_key - - @organism_obs_key.setter - def organism_obs_key(self, x: str): - self._organism_obs_key = x - - @property - def sample_source_obs_key(self) -> str: - return self._sample_source_obs_key - - @sample_source_obs_key.setter - def sample_source_obs_key(self, x: str): - self._sample_source_obs_key = x - - @property - def sex_obs_key(self) -> str: - return self._sex_obs_key - - @sex_obs_key.setter - def sex_obs_key(self, x: str): - self._sex_obs_key = x - - @property - def state_exact_obs_key(self) -> str: - return self._state_exact_obs_key - - @state_exact_obs_key.setter - def state_exact_obs_key(self, x: str): - self._state_exact_obs_key = x - - @property - def tech_sample_obs_key(self) -> str: - return self._tech_sample_obs_key - - @tech_sample_obs_key.setter - def tech_sample_obs_key(self, x: str): - self._tech_sample_obs_key = x - @property def organ(self) -> Union[None, str]: if self._organ is not None: @@ -1984,6 +1874,8 @@ def organism(self) -> Union[None, str]: @organism.setter def organism(self, x: str): x = self._value_protection(attr="organism", allowed=self.ontology_container_sfaira.organism, attempted=x) + # Update ontology container so that correct ontologies are queried: + self.ontology_container_sfaira.organism_cache = x self._organism = x @property @@ -2061,22 +1953,6 @@ def tech_sample(self) -> Union[None, str]: def tech_sample(self, x: str): self._tech_sample = x - @property - def gene_id_ensembl_var_key(self) -> str: - return self._gene_id_ensembl_var_key - - @gene_id_ensembl_var_key.setter - def gene_id_ensembl_var_key(self, x: str): - self._gene_id_ensembl_var_key = x - - @property - def gene_id_symbols_var_key(self) -> str: - return self._gene_id_symbols_var_key - - @gene_id_symbols_var_key.setter - def gene_id_symbols_var_key(self, x: str): - self._gene_id_symbols_var_key = x - @property def year(self) -> Union[None, int]: if 
self._year is not None: @@ -2094,29 +1970,21 @@ def year(self, x: int): x = self._value_protection(attr="year", allowed=self.ontology_container_sfaira.year, attempted=x) self._year = x - @property - def ontology_celltypes(self): - return self.ontology_container_sfaira.cellontology_class - - @property - def ontology_organ(self): - return self.ontology_container_sfaira.organ - @property def celltypes_universe(self): if self._celltype_universe is None: self._celltype_universe = CelltypeUniverse( - cl=self.ontology_celltypes, - uberon=self.ontology_container_sfaira.organ, + cl=getattr(self.ontology_container_sfaira, "cell_type"), + uberon=getattr(self.ontology_container_sfaira, "organ"), ) return self._celltype_universe @property - def cell_ontology_map(self) -> dict: + def cell_type_map(self) -> dict: return self._ontology_class_map - @cell_ontology_map.setter - def cell_ontology_map(self, x: pd.DataFrame): + @cell_type_map.setter + def cell_type_map(self, x: pd.DataFrame): assert x.shape[1] in [2, 3], f"{x.shape} in {self.id}" assert x.columns[0] == self._adata_ids.classmap_source_key assert x.columns[1] == self._adata_ids.classmap_target_key @@ -2192,8 +2060,6 @@ def title(self): else: return self.__crossref_query(k="title") - # Private methods: - def _value_protection( self, attr: str, @@ -2235,7 +2101,7 @@ def _value_protection( if isinstance(allowed, OntologyHierarchical) and x in allowed.node_ids: attempted_clean.append(allowed.convert_to_name(x)) else: - raise ValueError(f"'{x}' is not a valid entry for {attr}.") + raise ValueError(f"'{x}' is not a valid entry for {attr} in data set {self.doi}.") else: raise ValueError(f"argument allowed of type {type(allowed)} is not a valid entry for {attr}.") # Flatten attempts if only one was made: @@ -2256,7 +2122,7 @@ def subset_cells(self, key, values): - "assay_differentiation" points to self.assay_differentiation_obs_key - "assay_type_differentiation" points to self.assay_type_differentiation_obs_key - "cell_line" points to self.cell_line - - "cellontology_class" points to self.cellontology_class_obs_key + - "cell_type" points to self.cell_type_obs_key - "developmental_stage" points to self.developmental_stage_obs_key - "ethnicity" points to self.ethnicity_obs_key - "organ" points to self.organ_obs_key diff --git a/sfaira/data/dataloaders/base/dataset_group.py b/sfaira/data/dataloaders/base/dataset_group.py index 7ff2739fe..56f213c3b 100644 --- a/sfaira/data/dataloaders/base/dataset_group.py +++ b/sfaira/data/dataloaders/base/dataset_group.py @@ -42,6 +42,40 @@ def map_fn(inputs): return ds.id, e, +def merge_uns_from_list(adata_ls): + """ + Merge .uns from list of adata objects. + + Merges values for innert join of keys across all objects. This will retain uns streamlining. + Keeps shared uns values for a given key across data sets as single value (not a list of 1 unique value). + Other values are represented as a list of all unique values found. 
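The merging rule can be exercised on plain dictionaries, independent of anndata; this is a sketch of the intended behaviour with made-up DOI values, not the sfaira implementation itself:

import numpy as np

def merge_uns_dicts(uns_list):
    # Inner join of keys across all objects, unique values per key,
    # single shared values collapsed back to a scalar.
    shared = set(uns_list[0]).intersection(*(set(d) for d in uns_list[1:]))
    merged = {}
    for k in shared:
        vals = []
        for d in uns_list:
            v = d[k]
            vals.extend(v if isinstance(v, (list, tuple, np.ndarray)) else [v])
        vals = np.unique(vals).tolist()
        merged[k] = vals[0] if len(vals) == 1 else vals
    return merged

print(merge_uns_dicts([
    {"doi_journal": "10.1000/a", "organism": "human", "id": "ds1"},
    {"doi_journal": "10.1000/b", "organism": "human", "id": "ds2"},
]))
# organism collapses to 'human'; doi_journal and id become sorted lists of both values.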
+ """ + uns_keys = [list(x.uns.keys()) for x in adata_ls] + uns_keys_shared = set(uns_keys[0]) + for x in uns_keys[1:]: + uns_keys_shared = uns_keys_shared.intersection(set(x)) + uns_keys_shared = list(uns_keys_shared) + uns = {} + for k in uns_keys_shared: + uns_k = [] + for y in adata_ls: + x = y.uns[k] + if isinstance(x, list): + pass + elif isinstance(x, tuple): + x = list(x) + elif isinstance(x, np.ndarray): + x = x.tolist() + else: + x = [x] + uns_k.extend(x) + uns_k = np.sort(np.unique(uns_k)).tolist() + if len(uns_k) == 1: + uns_k = uns_k[0] + uns[k] = uns_k + return uns + + load_doc = \ """ :param remove_gene_version: Remove gene version string from ENSEMBL ID so that different versions in different data sets are superimposed. @@ -160,7 +194,10 @@ def streamline_metadata( clean_obs: bool = True, clean_var: bool = True, clean_uns: bool = True, - clean_obs_names: bool = True + clean_obs_names: bool = True, + keep_orginal_obs: bool = False, + keep_symbol_obs: bool = True, + keep_id_obs: bool = True, ): """ Streamline the adata instance in each data set to output format. @@ -173,6 +210,13 @@ def streamline_metadata( :param clean_var: Whether to delete non-streamlined fields in .var, .varm and .varp. :param clean_uns: Whether to delete non-streamlined fields in .uns. :param clean_obs_names: Whether to replace obs_names with a string comprised of dataset id and an increasing integer. + :param clean_obs_names: Whether to replace obs_names with a string comprised of dataset id and an increasing + integer. + :param keep_orginal_obs: For ontology-constrained .obs columns, whether to keep a column with original + annotation. + :param keep_symbol_obs: For ontology-constrained .obs columns, whether to keep a column with ontology symbol + annotation. + :param keep_id_obs: For ontology-constrained .obs columns, whether to keep a column with ontology ID annotation. 
:return: """ for x in self.ids: @@ -181,7 +225,10 @@ def streamline_metadata( clean_obs=clean_obs, clean_var=clean_var, clean_uns=clean_uns, - clean_obs_names=clean_obs_names + clean_obs_names=clean_obs_names, + keep_orginal_obs=keep_orginal_obs, + keep_symbol_obs=keep_symbol_obs, + keep_id_obs=keep_id_obs, ) def streamline_features( @@ -312,7 +359,7 @@ def write_ontology_class_map( for k, v in self.datasets.items(): if v.annotated: labels_original = np.sort(np.unique(np.concatenate([ - v.adata.obs[v.cell_types_original_obs_key].values + v.adata.obs[v.cell_type_original_obs_key].values ]))) tab.append(v.celltypes_universe.prepare_celltype_map_tab( source=labels_original, @@ -418,6 +465,7 @@ def adata(self): index_unique=None ) adata_concat.var = var_original + adata_concat.uns = merge_uns_from_list(adata_ls) adata_concat.uns[self._adata_ids.mapped_features] = match_ref_list[0] return adata_concat @@ -472,7 +520,7 @@ def project_celltypes_to_ontology(self, adata_fields: Union[AdataIds, None] = No :return: """ for _, v in self.datasets.items(): - v.project_celltypes_to_ontology(adata_fields=adata_fields, copy=copy) + v.project_free_to_ontology(adata_fields=adata_fields, copy=copy) def subset(self, key, values: Union[list, tuple, np.ndarray]): """ @@ -531,7 +579,7 @@ def subset_cells(self, key, values: Union[str, List[str]]): - "assay_sc" points to self.assay_sc_obs_key - "assay_type_differentiation" points to self.assay_type_differentiation_obs_key - "cell_line" points to self.cell_line - - "cellontology_class" points to self.cellontology_class_obs_key + - "cell_type" points to self.cell_type_obs_key - "developmental_stage" points to self.developmental_stage_obs_key - "ethnicity" points to self.ethnicity_obs_key - "organ" points to self.organ_obs_key @@ -622,8 +670,17 @@ def __init__( # Collect all data loaders from files in directory: datasets = [] self._cwd = os.path.dirname(file_base) - collection_id = str(self._cwd.split("/")[-1]) - package_source = "sfaira" if str(self._cwd.split("/")[-5]) == "sfaira" else "sfairae" + try: + collection_id = str(self._cwd).split(os.sep)[-1] + package_source = str(self._cwd).split(os.sep)[-5] + if package_source == "sfaira": + pass + elif package_source == "sfaira_extension": + package_source = "sfairae" + else: + raise ValueError(f"invalid package source {package_source} for {self._cwd}, {self.collection_id}") + except IndexError as e: + raise IndexError(f"{e} for {self._cwd}") loader_pydoc_path_sfaira = "sfaira.data.dataloaders.loaders." loader_pydoc_path_sfairae = "sfaira_extension.data.dataloaders.loaders." loader_pydoc_path = loader_pydoc_path_sfaira if package_source == "sfaira" else loader_pydoc_path_sfairae @@ -698,7 +755,7 @@ def __init__( ) # Load cell type maps: for x in datasets_f: - x.load_ontology_class_map(fn=os.path.join(self._cwd, file_module + ".tsv")) + x.read_ontology_class_map(fn=os.path.join(self._cwd, file_module + ".tsv")) datasets.extend(datasets_f) keys = [x.id for x in datasets] @@ -721,15 +778,15 @@ def clean_ontology_class_map(self): fn_map = os.path.join(self._cwd, file_module + ".tsv") if os.path.exists(fn_map): # Access reading and value protection mechanisms from first data set loaded in group. - tab = list(self.datasets.values())[0]._read_class_map(fn=fn_map) + tab = list(self.datasets.values())[0]._read_ontology_class_map(fn=fn_map) # Checks that the assigned ontology class names appear in the ontology. 
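For intuition, the class map cleaning amounts to appending an ontology ID column to a two-column mapping table; the column names, labels and CL ID below are illustrative examples, not values read from sfaira:

import pandas as pd

tab = pd.DataFrame({
    "source": ["alpha cell", "doublets"],
    "target": ["pancreatic A cell", "unknown"],
})
name_to_id = {"pancreatic A cell": "CL:0000171"}
tab["target_id"] = [name_to_id.get(x, "unknown") for x in tab["target"]]
print(tab)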
list(self.datasets.values())[0]._value_protection( - attr="celltypes", + attr="cell_type", allowed=self.ontology_celltypes, attempted=[ - x for x in np.unique(tab[self._adata_ids.classmap_target_key].values).tolist() + x for x in np.unique(tab[self._adata_ids.classmap_target_key].values) if x not in [ - self._adata_ids.unknown_celltype_identifier, + self._adata_ids.unknown_metadata_identifier, self._adata_ids.not_a_cell_celltype_identifier ] ] @@ -737,12 +794,14 @@ def clean_ontology_class_map(self): # Adds a third column with the corresponding ontology IDs into the file. tab[self._adata_ids.classmap_target_id_key] = [ self.ontology_celltypes.convert_to_id(x) - if x != self._adata_ids.unknown_celltype_identifier and - x != self._adata_ids.not_a_cell_celltype_identifier - else self._adata_ids.unknown_celltype_identifier + if (x != self._adata_ids.unknown_metadata_identifier and + x != self._adata_ids.not_a_cell_celltype_identifier) + else self._adata_ids.unknown_metadata_identifier for x in tab[self._adata_ids.classmap_target_key].values ] - list(self.datasets.values())[0]._write_class_map(fn=fn_map, tab=tab) + # Get writing function from any (first) data set instance: + k = list(self.datasets.keys())[0] + self.datasets[k]._write_ontology_class_map(fn=fn_map, tab=tab) class DatasetSuperGroup: @@ -991,6 +1050,7 @@ def adata(self): index_unique=None ) adata_concat.var = var_original + adata_concat.uns = merge_uns_from_list(adata_ls) adata_concat.uns[self._adata_ids.mapped_features] = match_ref_list[0] return adata_concat @@ -1100,7 +1160,7 @@ def write_backed( self._adata_ids.author, self._adata_ids.cell_line, self._adata_ids.dataset, - self._adata_ids.cellontology_class, + self._adata_ids.cell_type, self._adata_ids.development_stage, self._adata_ids.normalization, self._adata_ids.organ, @@ -1166,6 +1226,9 @@ def streamline_metadata( clean_var: bool = True, clean_uns: bool = True, clean_obs_names: bool = True, + keep_orginal_obs: bool = False, + keep_symbol_obs: bool = True, + keep_id_obs: bool = True, ): """ Streamline the adata instance in each group and each data set to output format. @@ -1178,6 +1241,13 @@ def streamline_metadata( :param clean_var: Whether to delete non-streamlined fields in .var, .varm and .varp. :param clean_uns: Whether to delete non-streamlined fields in .uns. :param clean_obs_names: Whether to replace obs_names with a string comprised of dataset id and an increasing integer. + :param clean_obs_names: Whether to replace obs_names with a string comprised of dataset id and an increasing + integer. + :param keep_orginal_obs: For ontology-constrained .obs columns, whether to keep a column with original + annotation. + :param keep_symbol_obs: For ontology-constrained .obs columns, whether to keep a column with ontology symbol + annotation. + :param keep_id_obs: For ontology-constrained .obs columns, whether to keep a column with ontology ID annotation. 
:return: """ for x in self.dataset_groups: @@ -1187,7 +1257,10 @@ def streamline_metadata( clean_obs=clean_obs, clean_var=clean_var, clean_uns=clean_uns, - clean_obs_names=clean_obs_names + clean_obs_names=clean_obs_names, + keep_orginal_obs=keep_orginal_obs, + keep_symbol_obs=keep_symbol_obs, + keep_id_obs=keep_id_obs, ) def subset(self, key, values): @@ -1277,7 +1350,7 @@ def subset_cells(self, key, values: Union[str, List[str]]): - "assay_differentiation" points to self.assay_differentiation_obs_key - "assay_type_differentiation" points to self.assay_type_differentiation_obs_key - "cell_line" points to self.cell_line - - "cellontology_class" points to self.cellontology_class_obs_key + - "cell_type" points to self.cell_type_obs_key - "developmental_stage" points to self.developmental_stage_obs_key - "ethnicity" points to self.ethnicity_obs_key - "organ" points to self.organ_obs_key @@ -1297,7 +1370,7 @@ def project_celltypes_to_ontology(self, adata_fields: Union[AdataIds, None] = No :return: """ for _, v in self.dataset_groups: - v.project_celltypes_to_ontology(adata_fields=adata_fields, copy=copy) + v.project_free_to_ontology(adata_fields=adata_fields, copy=copy) def write_config(self, fn: Union[str, os.PathLike]): """ diff --git a/sfaira/data/dataloaders/databases/cellxgene/cellxgene_loader.py b/sfaira/data/dataloaders/databases/cellxgene/cellxgene_loader.py index 3855f4c56..7eb65ba48 100644 --- a/sfaira/data/dataloaders/databases/cellxgene/cellxgene_loader.py +++ b/sfaira/data/dataloaders/databases/cellxgene/cellxgene_loader.py @@ -9,7 +9,7 @@ import uuid from sfaira.data.dataloaders.base import DatasetBase -from sfaira.consts import AdataIdsCellxgene +from sfaira.consts import AdataIdsCellxgene, AdataIdsCellxgeneHuman_v1_1_0, AdataIdsCellxgeneMouse_v1_1_0 from sfaira.consts.directories import CACHE_DIR_DATABASES_CELLXGENE from sfaira.data.dataloaders.databases.cellxgene.rest_helpers import get_collection, get_data from sfaira.data.dataloaders.databases.cellxgene.rest_helpers import CELLXGENE_PRODUCTION_ENDPOINT, DOWNLOAD_DATASET @@ -19,6 +19,63 @@ def cellxgene_fn(dir, dataset_id): return os.path.join(dir, dataset_id + ".h5ad") +def clean_cellxgene_meta_obs(k, val, adata_ids) -> Union[str, List[str]]: + """ + :param k: Found meta data name. + :param val: Found meta data entry. + :returns: Cleaned meta data entry. + """ + if k == "disease": + # TODO normal state label varies in disease annotation. This can be removed once streamlined. + val = ["healthy" if (v.lower() == "normal" or v.lower() == "healthy") else v for v in val] + elif k == "organ": + # Organ labels contain labels on tissue type also, such as 'UBERON:0001911 (cell culture)'. + val = [v.split(" ")[0] for v in val] + elif k == "organism": + organism_map = { + "Homo sapiens": "human", + "Mus musculus": "mouse", + } + val = [organism_map[v] if v in organism_map.keys() else v for v in val] + return val + + +def clean_cellxgene_meta_uns(k, val, adata_ids) -> Union[str, List[str]]: + """ + :param k: Found meta data name. + :param val: Found meta data entry. + :returns: Cleaned meta data entry. + """ + x_clean = [] + for v in val: + if k == "sex": + v = v[0] + else: + # Decide if labels are read from name or ontology ID: + if k == "disease" and (v["label"].lower() == "normal" or v["label"].lower() == "healthy"): + # TODO normal state label varies in disease annotation. This can be removed once streamlined. 
+ v = "healthy" + elif k in ["assay_sc", "disease", "organ"] and \ + v["ontology_term_id"] != adata_ids.unknown_metadata_identifier: + v = v["ontology_term_id"] + else: + v = v["label"] + # Organ labels contain labels on tissue type also, such as 'UBERON:0001911 (cell culture)'. + if k == "organ": + v = v.split(" ")[0] + if k == "organism": + organism_map = { + "Homo sapiens": "human", + "Mus musculus": "mouse", + } + if v not in organism_map: + raise ValueError(f"value {v} not recognized") + v = organism_map[v] + if v != adata_ids.unknown_metadata_identifier and v != adata_ids.invalid_metadata_identifier: + x_clean.append(v) + return x_clean + + class Dataset(DatasetBase): """ This is a dataloader for downloaded h5ad from cellxgene. @@ -52,26 +109,10 @@ def __init__( sample_fn=sample_fn, sample_fns=sample_fns, ) + # General keys are defined in the shared IDs object. Further down, the species specific one is loaded to + # disambiguate species-dependent differences. self._adata_ids_cellxgene = AdataIdsCellxgene() self._collection = None - - # The h5ad objects from cellxgene follow a particular structure and the following attributes are guaranteed to - # be in place. Note that these point at the anndata instance and will only be available for evaluation after - # download. See below for attributes that are lazily available - self.cellontology_class_obs_key = self._adata_ids_cellxgene.cellontology_class - self.cellontology_id_obs_key = self._adata_ids_cellxgene.cellontology_id - self.cellontology_original_obs_key = self._adata_ids_cellxgene.cell_types_original - self.development_stage_obs_key = self._adata_ids_cellxgene.development_stage - self.disease_obs_key = self._adata_ids_cellxgene.disease - self.ethnicity_obs_key = self._adata_ids_cellxgene.ethnicity - self.sex_obs_key = self._adata_ids_cellxgene.sex - self.organ_obs_key = self._adata_ids_cellxgene.organism - self.state_exact_obs_key = self._adata_ids_cellxgene.state_exact - - self.gene_id_symbols_var_key = self._adata_ids_cellxgene.gene_id_symbols - - self._unknown_celltype_identifiers = self._adata_ids_cellxgene.unknown_celltype_identifier - self.collection_id = collection_id self.supplier = "cellxgene" doi = [x['link_url'] for x in self.collection["links"] if x['link_type'] == 'DOI'] @@ -94,34 +135,7 @@ def __init__( # Otherwise do not set property and resort to cell-wise labels. if isinstance(val, dict) or k == "sex": val = [val] - v_clean = [] - for v in val: - if k == "sex": - v = v[0] - else: - # Decide if labels are read from name or ontology ID: - if k == "disease" and (v["label"].lower() == "normal" or v["label"].lower() == "healthy"): - # TODO normal state label varies in disease annotation. This can be removed once streamlined. - v = "healthy" - elif k in ["assay_sc", "disease", "organ"] and \ - v["ontology_term_id"] != self._adata_ids_cellxgene.unknown_metadata_ontology_id_identifier: - v = v["ontology_term_id"] - else: - v = v["label"] - # Organ labels contain labels on tissue type also, such as 'UBERON:0001911 (cell culture)'. 
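The individual clean-up steps are simple string operations; a stand-alone illustration with made-up input values:

organism_map = {"Homo sapiens": "human", "Mus musculus": "mouse"}

organ = "UBERON:0001911 (cell culture)".split(" ")[0]          # -> 'UBERON:0001911'
organism = organism_map.get("Homo sapiens", "Homo sapiens")    # -> 'human'
disease = "Normal"
disease = "healthy" if disease.lower() in ("normal", "healthy") else disease
print(organ, organism, disease)                                # UBERON:0001911 human healthy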
- if k == "organ": - v = v.split(" ")[0] - if k == "organism": - organism_map = { - "Homo sapiens": "human", - "Mus musculus": "mouse", - } - if v not in organism_map: - raise ValueError(f"value {v} not recognized") - v = organism_map[v] - if v != self._adata_ids_cellxgene.unknown_metadata_ontology_id_identifier and \ - v != self._adata_ids_cellxgene.invalid_metadata_identifier: - v_clean.append(v) + v_clean = clean_cellxgene_meta_uns(k=k, val=val, adata_ids=self._adata_ids_cellxgene) try: # Set as single element or list if multiple entries are given. if len(v_clean) == 1: @@ -131,8 +145,29 @@ def __init__( except ValueError as e: if verbose > 0: print(f"WARNING: {e} in {self.collection_id} and data set {self.id}") + + if self.organism == "human": + self._adata_ids_cellxgene = AdataIdsCellxgeneHuman_v1_1_0() + elif self.organism == "mouse": + self._adata_ids_cellxgene = AdataIdsCellxgeneMouse_v1_1_0() + else: + assert False, self.organism # Add author information. # TODO need to change this to contributor? setattr(self, "author", "cellxgene") + # The h5ad objects from cellxgene follow a particular structure and the following attributes are guaranteed to + # be in place. Note that these point at the anndata instance and will only be available for evaluation after + # download. See below for attributes that are lazily available + self.cell_type_obs_key = self._adata_ids_cellxgene.cell_type + self.development_stage_obs_key = self._adata_ids_cellxgene.development_stage + self.disease_obs_key = self._adata_ids_cellxgene.disease + self.ethnicity_obs_key = self._adata_ids_cellxgene.ethnicity + self.sex_obs_key = self._adata_ids_cellxgene.sex + self.organ_obs_key = self._adata_ids_cellxgene.organism + self.state_exact_obs_key = self._adata_ids_cellxgene.state_exact + + self.gene_id_symbols_var_key = self._adata_ids_cellxgene.feature_symbol + + self._unknown_celltype_identifiers = self._adata_ids_cellxgene.unknown_metadata_identifier @property def _collection_cache_dir(self): @@ -169,6 +204,10 @@ def _collection_dataset(self): def directory_formatted_doi(self) -> str: return self.collection_id + @property + def doi_cleaned_id(self): + return self.id + def load( self, remove_gene_version: bool = True, @@ -186,6 +225,7 @@ def load( load_raw=True, allow_caching=False, set_metadata=set_metadata, + adata_ids=self._adata_ids_cellxgene, **kwargs ) @@ -234,7 +274,7 @@ def show_summary(self): """ % (uuid_session, json.dumps(self._collection_dataset)), raw=True) -def load(data_dir, sample_fn, **kwargs): +def load(data_dir, sample_fn, adata_ids: AdataIdsCellxgene, **kwargs): """ Generalised load function for cellxgene-provided data sets. @@ -245,4 +285,8 @@ def load(data_dir, sample_fn, **kwargs): if adata.raw is not None: # TODO still need this? 
adata.X = adata.raw.X del adata.raw + for k in adata_ids.ontology_constrained: + col_name = getattr(adata_ids, k) + if col_name in adata.obs.columns: + adata.obs[col_name] = clean_cellxgene_meta_obs(k=k, val=adata.obs[col_name].values, adata_ids=adata_ids) return adata diff --git a/sfaira/data/dataloaders/export_adaptors/__init__.py b/sfaira/data/dataloaders/export_adaptors/__init__.py new file mode 100644 index 000000000..c8a133645 --- /dev/null +++ b/sfaira/data/dataloaders/export_adaptors/__init__.py @@ -0,0 +1 @@ +from sfaira.data.dataloaders.export_adaptors.cellxgene import cellxgene_export_adaptor diff --git a/sfaira/data/dataloaders/export_adaptors/cellxgene.py b/sfaira/data/dataloaders/export_adaptors/cellxgene.py new file mode 100644 index 000000000..21c8764f0 --- /dev/null +++ b/sfaira/data/dataloaders/export_adaptors/cellxgene.py @@ -0,0 +1,89 @@ +import anndata +import numpy as np +import scipy.sparse +from typing import Union + +from sfaira.consts.adata_fields import AdataIdsCellxgene + +DEFAULT_CELLXGENE_VERSION = "1_1_0" + + +def cellxgene_export_adaptor(adata: anndata.AnnData, adata_ids: AdataIdsCellxgene, version: Union[None, str]) \ + -> anndata.AnnData: + """ + Projects a streamlined data set to the export-ready cellxgene format. + """ + if version is None: + version = DEFAULT_CELLXGENE_VERSION + if version == "1_1_0": + return cellxgene_export_adaptor_1_1_0(adata=adata, adata_ids=adata_ids) + else: + raise ValueError(f"Did not recognise cellxgene schema version {version}") + + +def cellxgene_export_adaptor_1_1_0(adata: anndata.AnnData, adata_ids: AdataIdsCellxgene) -> anndata.AnnData: + """ + Cellxgene-schema 1.1.0 + """ + # Check input object characteristics: + cellxgene_cols = [getattr(adata_ids, x) for x in adata_ids.ontology_constrained] + \ + [getattr(adata_ids, x) for x in adata_ids.obs_keys + if x not in adata_ids.ontology_constrained] + \ + [getattr(adata_ids, x) + adata_ids.onto_id_suffix + for x in adata_ids.ontology_constrained] + for x in cellxgene_cols: + if x not in adata.obs.columns: + raise ValueError(f"Cannot streamlined data set {adata.uns['id']} to cellxgene format because meta data {x} " + f"is missing and the corresponding .obs column could not be written.\n" + f"Columns found were {adata.obs.columns}.") + # 1) Modify .uns + adata.uns["layer_descriptions"] = {"X": "raw"} + adata.uns["version"] = { + "corpora_encoding_version": "0.1.0", + "corpora_schema_version": "1.1.0", + } + adata.uns["contributors"] = { + "name": "sfaira", + "email": "https://github.com/theislab/sfaira/issues", + "institution": "sfaira", + } + # TODO port this into organism ontology handling. + # Infer organism from adata object. 
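The organism handling that follows boils down to a fixed two-entry mapping from sfaira organism labels to the cellxgene .uns fields; a stand-alone sketch (the function name is illustrative):

CELLXGENE_ORGANISM = {
    "human": ("Homo sapiens", "NCBITaxon:9606"),
    "mouse": ("Mus musculus", "NCBITaxon:10090"),
}

def organism_uns_fields(organism: str) -> dict:
    if organism not in CELLXGENE_ORGANISM:
        raise ValueError(f"organism {organism} currently not supported by cellxgene schema")
    name, term_id = CELLXGENE_ORGANISM[organism]
    return {"organism": name, "organism_ontology_term_id": term_id}

print(organism_uns_fields("mouse"))
# {'organism': 'Mus musculus', 'organism_ontology_term_id': 'NCBITaxon:10090'}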
organism = np.unique(adata.obs[adata_ids.organism].values)[0]
+    if organism == "mouse":
+        adata.uns["organism"] = "Mus musculus"
+        adata.uns["organism_ontology_term_id"] = "NCBITaxon:10090"
+    elif organism == "human":
+        adata.uns["organism"] = "Homo sapiens"
+        adata.uns["organism_ontology_term_id"] = "NCBITaxon:9606"
+    else:
+        raise ValueError(f"organism {organism} currently not supported by cellxgene schema")
+    # 2) Modify .obs
+    # Correct unknown cell type entries:
+    adata.obs[adata_ids.cell_type] = [
+        x if x not in [adata_ids.unknown_metadata_identifier, adata_ids.not_a_cell_celltype_identifier]
+        else "native cell" for x in adata.obs[adata_ids.cell_type]]
+    adata.obs[adata_ids.cell_type + adata_ids.onto_id_suffix] = [
+        x if x not in [adata_ids.unknown_metadata_identifier, adata_ids.not_a_cell_celltype_identifier]
+        else "CL:0000003" for x in adata.obs[adata_ids.cell_type + adata_ids.onto_id_suffix]]
+    # Reorder data frame to put ontology columns first:
+    adata.obs = adata.obs[cellxgene_cols + [x for x in adata.obs.columns if x not in cellxgene_cols]]
+    # 3) Modify .X
+    # Check if .X is counts: the conversion is based on the assumption that .X is csr.
+    assert isinstance(adata.X, scipy.sparse.csr_matrix), type(adata.X)
+    count_values = np.unique(np.asarray(adata.X.todense()))
+    if not np.all(count_values % 1. == 0.):
+        print(f"WARNING: not all count entries were counts, "
+              f"the maximum deviation from integer is "
+              f"{np.max([x % 1. if x % 1. < 0.5 else 1. - x % 1. for x in count_values])}. "
+              f"The count matrix is rounded.")
+        adata.X.data = np.rint(adata.X.data)
+    return adata
+
+
+def cellxgene_export_adaptor_2_0_0(adata: anndata.AnnData, adata_ids: AdataIdsCellxgene) -> anndata.AnnData:
+    """
+    Cellxgene-schema 2.0.0
+    """
+    adata = cellxgene_export_adaptor_1_1_0(adata=adata, adata_ids=adata_ids)
+    adata.var["feature_biotype"] = "gene"
+    return adata
diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2017_09_004/human_isletoflangerhans_2017_smartseq2_enge_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2017_09_004/human_isletoflangerhans_2017_smartseq2_enge_001.py
index 47e162a49..8434e4e33 100644
--- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2017_09_004/human_isletoflangerhans_2017_smartseq2_enge_001.py
+++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2017_09_004/human_isletoflangerhans_2017_smartseq2_enge_001.py
@@ -26,7 +26,7 @@ def __init__(self, **kwargs):
         self.organism = "human"
         self.year = 2017
         self.gene_id_symbols_var_key = "index"
-        self.cell_types_original_obs_key = "celltype"
+        self.cell_types_obs_key = "celltype"
         self.sample_source = "primary_tissue"
 
         self.set_dataset_id(idx=1)
diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_02_001/mouse_x_2018_microwellseq_han_x.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_02_001/mouse_x_2018_microwellseq_han_x.py
index bc37cd92c..4cae46762 100644
--- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_02_001/mouse_x_2018_microwellseq_han_x.py
+++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_02_001/mouse_x_2018_microwellseq_han_x.py
@@ -311,7 +311,7 @@ def __init__(self, **kwargs):
         self.gene_id_symbols_var_key = "index"
 
         # Only adult and neonatal samples are annotated:
-        self.cell_types_original_obs_key = "Annotation" \
+        self.cell_types_obs_key = "Annotation" \
             if sample_dev_stage_dict[self.sample_fn] in ["adult", "neonatal"] and \
             self.sample_fn not in [
                 "NeontalBrain1_dge.txt.gz",
diff --git
a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml index 3a1b866d4..20047cb3c 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2018_08_067/human_laminapropriaofmucosaofcolon_2019_10xsequencing_kinchen_001.yaml @@ -52,7 +52,7 @@ dataset_or_observation_wise: tech_sample: tech_sample_obs_key: observation_wise: - cell_types_original_obs_key: "Cluster" + cell_type_obs_key: "Cluster" feature_wise: gene_id_ensembl_var_key: gene_id_symbols_var_key: "index" diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py index 802854189..4e9b6e19f 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_06_029/human_colonicepithelium_2019_10xsequencing_smilie_001.py @@ -26,7 +26,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py index c3aff84f9..bda83b0ee 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cell_2019_08_008/human_ileum_2019_10xsequencing_martin_001.py @@ -26,7 +26,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "index" self.gene_id_ensembl_var_key = "gene_ids" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py index 7a46ff91e..584016f99 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_celrep_2018_11_086/human_prostategland_2018_10xsequencing_henry_001.py @@ -31,7 +31,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py index d07a2ada2..6cc230afb 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cels_2016_08_011/human_pancreas_2016_indrop_baron_001.py @@ -25,7 +25,7 @@ def __init__(self, **kwargs): self.year = 2016 self.gene_id_symbols_var_key = "index" - 
self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py index b9c4c2657..4eee0bee5 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2016_08_020/human_pancreas_2016_smartseq2_segerstolpe_001.py @@ -26,7 +26,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "Characteristics[cell type]" + self.cell_type_obs_key = "Characteristics[cell type]" self.state_exact_obs_key = "Characteristics[disease]" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2019_01_021/mouse_pancreas_2019_10xsequencing_thompson_x.py b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2019_01_021/mouse_pancreas_2019_10xsequencing_thompson_x.py index a6821a2af..203618129 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2019_01_021/mouse_pancreas_2019_10xsequencing_thompson_x.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_cmet_2019_01_021/mouse_pancreas_2019_10xsequencing_thompson_x.py @@ -40,7 +40,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "names" self.gene_id_ensembl_var_key = "ensembl" - self.cell_types_original_obs_key = "celltypes" + self.cell_type_obs_key = "celltypes" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py b/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py index 7aff4b8f9..09c9e9dae 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_devcel_2020_01_033/human_lung_2020_10xsequencing_miller_001.py @@ -25,7 +25,7 @@ def __init__(self, **kwargs): self.year = 2020 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "Cell_type" + self.cell_type_obs_key = "Cell_type" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1016_j_neuron_2019_06_011/human_brain_2019_dropseq_polioudakis_001.yaml b/sfaira/data/dataloaders/loaders/d10_1016_j_neuron_2019_06_011/human_brain_2019_dropseq_polioudakis_001.yaml index c099579ba..dd83c86a8 100644 --- a/sfaira/data/dataloaders/loaders/d10_1016_j_neuron_2019_06_011/human_brain_2019_dropseq_polioudakis_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1016_j_neuron_2019_06_011/human_brain_2019_dropseq_polioudakis_001.yaml @@ -43,7 +43,7 @@ dataset_or_observation_wise: tech_sample: tech_sample_obs_key: "Index" observation_wise: - cell_types_original_obs_key: "celltype" + cell_type_obs_key: "celltype" feature_wise: gene_id_ensembl_var_key: gene_id_symbols_var_key: "index" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_nmeth_4407/human_brain_2017_droncseq_habib_001.py b/sfaira/data/dataloaders/loaders/d10_1038_nmeth_4407/human_brain_2017_droncseq_habib_001.py index 614c855a5..7046f536a 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_nmeth_4407/human_brain_2017_droncseq_habib_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_nmeth_4407/human_brain_2017_droncseq_habib_001.py @@ -24,7 +24,7 @@ def __init__(self, 
**kwargs): self.year = 2017 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py index fda7a7ea1..232313aad 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41422_018_0099_2/human_testis_2018_10xsequencing_guo_001.py @@ -24,7 +24,7 @@ def __init__(self, **kwargs): self.year = 2018 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py index 9039758bd..87a7005df 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_018_06318_7/human_caudatelobeofliver_2018_10xsequencing_macparland_001.py @@ -23,7 +23,7 @@ def __init__(self, **kwargs): self.year = 2018 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "celltype" + self.cell_type_obs_key = "celltype" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_10861_2/human_kidney_2019_droncseq_lake_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_10861_2/human_kidney_2019_droncseq_lake_001.py index 5cd22c4da..2f6a75b7b 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_10861_2/human_kidney_2019_droncseq_lake_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_10861_2/human_kidney_2019_droncseq_lake_001.py @@ -25,7 +25,7 @@ def __init__(self, **kwargs): self.year = 2019 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "celltype" + self.cell_type_obs_key = "celltype" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py index 87a40002c..5d437c4ae 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12464_3/human_x_2019_10xsequencing_szabo_001.py @@ -72,7 +72,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "Gene" self.gene_id_ensembl_var_key = "Accession" - self.cell_types_original_obs_key = "cell_ontology_class" + self.cell_type_obs_key = "cell_ontology_class" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py index f09696431..ab8580b92 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41467_019_12780_8/human_retina_2019_10xsequencing_menon_001.py @@ -22,7 +22,7 @@ def __init__(self, 
**kwargs): self.year = 2019 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py index 9af36b696..d246b6342 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_018_0698_6/human_placenta_2018_x_ventotormo_001.py @@ -31,7 +31,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "names" self.gene_id_ensembl_var_key = "ensembl" - self.cell_types_original_obs_key = "annotation" + self.cell_type_obs_key = "annotation" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1373_2/human_liver_2019_celseq2_aizarani_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1373_2/human_liver_2019_celseq2_aizarani_001.py index 51ea692a6..98bc5df88 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1373_2/human_liver_2019_celseq2_aizarani_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1373_2/human_liver_2019_celseq2_aizarani_001.py @@ -23,7 +23,7 @@ def __init__(self, **kwargs): self.year = 2019 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py index eac9db452..35ae2652b 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1631_3/human_liver_2019_10xsequencing_ramachandran_001.py @@ -25,7 +25,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "annotation_lineage" + self.cell_type_obs_key = "annotation_lineage" self.state_exact_obs_key = "condition" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py index 65f6d5ac7..197e6ae83 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1652_y/human_liver_2019_10xsequencing_popescu_001.py @@ -23,7 +23,7 @@ def __init__(self, **kwargs): self.year = 2019 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "cell.labels" + self.cell_type_obs_key = "cell.labels" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1654_9/human_brain_2019_10x3v2sequencing_kanton_001.yaml b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1654_9/human_brain_2019_10x3v2sequencing_kanton_001.yaml index 9b71b382d..03fe62b8b 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1654_9/human_brain_2019_10x3v2sequencing_kanton_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_019_1654_9/human_brain_2019_10x3v2sequencing_kanton_001.yaml @@ -45,7 +45,7 @@ 
dataset_or_observation_wise: tech_sample: tech_sample_obs_key: observation_wise: - cell_types_original_obs_key: + cell_type_obs_key: feature_wise: gene_id_ensembl_var_key: "ensembl" gene_id_symbols_var_key: "index" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2157_4/human_x_2020_microwellseq_han_x.py b/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2157_4/human_x_2020_microwellseq_han_x.py index fb6b17063..fd0b851cb 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2157_4/human_x_2020_microwellseq_han_x.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2157_4/human_x_2020_microwellseq_han_x.py @@ -29,7 +29,7 @@ def __init__(self, **kwargs): self.sample_source = "primary_tissue" self.bio_sample_obs_key = "sample" - self.cell_types_original_obs_key = "celltype_specific" + self.cell_type_obs_key = "celltype_specific" self.development_stage_obs_key = "dev_stage" self.organ_obs_key = "organ" self.sex_obs_key = "sex" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml b/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml index 8fc3e343b..356db8dc2 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41586_020_2922_4/human_lung_2020_x_travaglini_001.yaml @@ -54,7 +54,7 @@ dataset_or_observation_wise: droplet_normal_lung_blood_scanpy.20200205.RC4.h5ad: "channel" facs_normal_lung_blood_scanpy.20200205.RC4.h5ad: "plate.barcode" observation_wise: - cell_types_original_obs_key: "free_annotation" + cell_type_obs_key: "free_annotation" feature_wise: gene_id_ensembl_var_key: gene_id_symbols_var_key: "index" diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py index 31d204574..d768f5df1 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41590_020_0602_z/human_colon_2020_10xsequencing_james_001.py @@ -27,7 +27,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "index" self.gene_id_ensembl_var_key = "gene_ids" - self.cell_types_original_obs_key = "cell_type" + self.cell_type_obs_key = "cell_type" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_lung_2019_dropseq_braga_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_lung_2019_dropseq_braga_001.py index 2981bfda0..581abc96e 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_lung_2019_dropseq_braga_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_lung_2019_dropseq_braga_001.py @@ -24,7 +24,7 @@ def __init__(self, **kwargs): self.year = 2019 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "celltype" + self.cell_type_obs_key = "celltype" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py b/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py index c6e0b9f3e..bfc3fb3d8 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py +++ 
b/sfaira/data/dataloaders/loaders/d10_1038_s41591_019_0468_5/human_x_2019_10xsequencing_braga_x.py @@ -28,7 +28,7 @@ def __init__(self, **kwargs): self.year = 2019 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py b/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py index 07ccb350d..eb8387021 100644 --- a/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1038_s41593_019_0393_4/mouse_x_2019_10xsequencing_hove_001.py @@ -27,7 +27,7 @@ def __init__(self, **kwargs): self.year = 2019 self.bio_sample_obs_key = "sample" - self.cell_types_original_obs_key = "cluster" + self.cell_type_obs_key = "cluster" self.organ_obs_key = "organ" self.gene_id_ensembl_var_key = "ensembl" diff --git a/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py b/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py index 4f528b133..a7a8811c2 100644 --- a/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1073_pnas_1914143116/human_retina_2019_10xsequencing_voigt_001.py @@ -23,7 +23,7 @@ def __init__(self, **kwargs): self.year = 2019 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py b/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py index 4103cd1b6..903a7c625 100644 --- a/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1084_jem_20191130/human_x_2019_10xsequencing_wang_001.py @@ -32,7 +32,7 @@ def __init__(self, **kwargs): self.year = 2019 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py b/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py index 0fa4ae2ee..5e998bd36 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1101_2020_03_13_991455/human_lung_2020_10xsequencing_lukassen_001.py @@ -30,7 +30,7 @@ def __init__(self, **kwargs): self.year = 2020 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml b/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml index 171dcbee0..48ab8650a 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1101_2020_10_12_335331/human_blood_2020_10x_hao_001.yaml @@ -44,7 +44,7 @@ dataset_or_observation_wise: tech_sample: 
tech_sample_obs_key: 'Batch' observation_wise: - cell_types_original_obs_key: "celltype.l3" + cell_type_obs_key: "celltype.l3" feature_wise: gene_id_ensembl_var_key: gene_id_symbols_var_key: "names" diff --git a/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py b/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py index 16b7bda98..baa5bfd9a 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py +++ b/sfaira/data/dataloaders/loaders/d10_1101_661728/mouse_x_2019_x_pisco_x.py @@ -76,7 +76,7 @@ def __init__(self, **kwargs): f"{self.sample_fn}" self.download_url_meta = None - self.cell_types_original_obs_key = "cell_ontology_class" + self.cell_type_obs_key = "cell_ontology_class" self.development_stage_obs_key = "development_stage" self.sex_obs_key = "sex" # ToDo: further anatomical information for subtissue in "subtissue"? diff --git a/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py b/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py index 973d33b91..a26bdc42a 100644 --- a/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1101_753806/human_lungparenchyma_2020_10xsequencing_habermann_001.py @@ -42,7 +42,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "celltype" + self.cell_type_obs_key = "celltype" self.state_exact_obs_key = "Diagnosis" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py b/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py index a4e2781d5..24b828ebf 100644 --- a/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1126_science_aat5031/human_kidney_2019_10xsequencing_stewart_001.py @@ -32,7 +32,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "index" self.gene_id_ensembl_var_key = "ID" - self.cell_types_original_obs_key = "celltype" + self.cell_type_obs_key = "celltype" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py b/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py index d304fc729..80465d0ca 100644 --- a/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1126_science_aay3224/human_thymus_2020_10xsequencing_park_001.py @@ -33,7 +33,7 @@ def __init__(self, **kwargs): self.year = 2020 self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "Anno_level_fig1" + self.cell_type_obs_key = "Anno_level_fig1" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_1126_science_aba7721/human_x_2020_scirnaseq_cao_001.yaml b/sfaira/data/dataloaders/loaders/d10_1126_science_aba7721/human_x_2020_scirnaseq_cao_001.yaml index a91663fca..9288f7276 100644 --- a/sfaira/data/dataloaders/loaders/d10_1126_science_aba7721/human_x_2020_scirnaseq_cao_001.yaml +++ b/sfaira/data/dataloaders/loaders/d10_1126_science_aba7721/human_x_2020_scirnaseq_cao_001.yaml @@ -44,7 +44,7 @@ dataset_or_observation_wise: tech_sample: tech_sample_obs_key: 
"Experiment_batch" observation_wise: - cell_types_original_obs_key: "Main_cluster_name" + cell_type_obs_key: "Main_cluster_name" feature_wise: gene_id_ensembl_var_key: "gene_id" gene_id_symbols_var_key: "gene_short_name" diff --git a/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py b/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py index 20c5d99fe..ca63c18f2 100644 --- a/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py +++ b/sfaira/data/dataloaders/loaders/d10_1186_s13059_019_1906_x/human_x_2019_10xsequencing_madissoon_001.py @@ -45,7 +45,7 @@ def __init__(self, **kwargs): self.sample_source = "primary_tissue" self.gene_id_symbols_var_key = "index" - self.cell_types_original_obs_key = "Celltypes" + self.cell_type_obs_key = "Celltypes" self.set_dataset_id(idx=1) diff --git a/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py b/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py index 2ad219b55..19c4ffcb9 100644 --- a/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py +++ b/sfaira/data/dataloaders/loaders/d10_15252_embj_2018100811/human_retina_2019_10xsequencing_lukowski_001.py @@ -26,7 +26,7 @@ def __init__(self, **kwargs): self.gene_id_symbols_var_key = "index" self.gene_id_ensembl_var_key = "gene_ids" - self.cell_types_original_obs_key = "CellType" + self.cell_type_obs_key = "CellType" self.set_dataset_id(idx=1) diff --git a/sfaira/data/store/__init__.py b/sfaira/data/store/__init__.py index 48f12be63..14fd14808 100644 --- a/sfaira/data/store/__init__.py +++ b/sfaira/data/store/__init__.py @@ -1,3 +1,4 @@ -from sfaira.data.store.multi_store import load_store, DistributedStoreMultipleFeatureSpaceBase, \ +from sfaira.data.store.load_store import load_store +from sfaira.data.store.multi_store import DistributedStoreMultipleFeatureSpaceBase, \ DistributedStoresH5ad, DistributedStoresDao from sfaira.data.store.single_store import DistributedStoreSingleFeatureSpace diff --git a/sfaira/data/store/base.py b/sfaira/data/store/base.py new file mode 100644 index 000000000..283bd2395 --- /dev/null +++ b/sfaira/data/store/base.py @@ -0,0 +1,107 @@ +import abc +import anndata +import dask.dataframe +import numpy as np +import os +import pandas as pd +import sys +from typing import Dict, List, Union + + +class DistributedStoreBase(abc.ABC): + """ + Base class for store API for attribute typing and shared methods. 
+ """ + + @property + @abc.abstractmethod + def adata_by_key(self) -> Dict[str, anndata.AnnData]: + pass + + @property + @abc.abstractmethod + def data_by_key(self): + pass + + @property + @abc.abstractmethod + def indices(self) -> Dict[str, np.ndarray]: + pass + + @property + @abc.abstractmethod + def genome_container(self): + pass + + @property + @abc.abstractmethod + def n_obs(self) -> int: + pass + + @property + @abc.abstractmethod + def n_vars(self): + pass + + @property + @abc.abstractmethod + def obs(self): + pass + + @property + @abc.abstractmethod + def obs_by_key(self) -> Dict[str, Union[pd.DataFrame, dask.dataframe.DataFrame]]: + pass + + @property + @abc.abstractmethod + def var_names(self): + pass + + @property + @abc.abstractmethod + def shape(self): + pass + + @property + @abc.abstractmethod + def var(self): + pass + + @property + @abc.abstractmethod + def X(self): + pass + + @abc.abstractmethod + def subset(self, attr_key, values: Union[str, List[str], None], + excluded_values: Union[str, List[str], None], verbose: int): + pass + + @abc.abstractmethod + def write_config(self, fn: Union[str, os.PathLike]): + pass + + @abc.abstractmethod + def load_config(self, fn: Union[str, os.PathLike]): + pass + + @abc.abstractmethod + def generator( + self, + idx: Union[np.ndarray, None], + batch_size: int, + obs_keys: List[str], + return_dense: bool, + randomized_batch_access: bool, + random_access: bool, + **kwargs + ) -> iter: + pass + + @property + def adata_memory_footprint(self) -> Dict[str, float]: + """ + Memory foot-print of data set k in MB. + """ + return dict([(k, sys.getsizeof(v) / np.power(1024, 2)) for k, v in self.adata_by_key.items()]) diff --git a/sfaira/data/store/batch_schedule.py b/sfaira/data/store/batch_schedule.py new file mode 100644 index 000000000..9a7f86e03 --- /dev/null +++ b/sfaira/data/store/batch_schedule.py @@ -0,0 +1,127 @@ +import numpy as np +from typing import List, Tuple + + +def _get_batch_start_ends(idx: np.ndarray, batch_size: int): + n_obs = len(idx) + remainder = n_obs % batch_size if n_obs > 0 else 0 + n_batches = int(n_obs // batch_size + int(remainder > 0)) if n_obs > 0 else 0 + batch_starts_ends = [ + (int(x * batch_size), int(np.minimum((x * batch_size) + batch_size, n_obs))) + for x in np.arange(0, n_batches) + ] + return batch_starts_ends + + +def _randomize_batch_start_ends(batch_starts_ends): + batch_range = np.arange(0, len(batch_starts_ends)) + np.random.shuffle(batch_range) + batch_starts_ends = [batch_starts_ends[i] for i in batch_range] + return batch_starts_ends + + +class BatchDesignBase: + + def __init__(self, retrieval_batch_size: int, randomized_batch_access: bool, random_access: bool, **kwargs): + self.retrieval_batch_size = retrieval_batch_size + self._idx = None + if randomized_batch_access and random_access: + raise ValueError("Do not use randomized_batch_access and random_access.") + self.randomized_batch_access = randomized_batch_access + self.random_access = random_access + + @property + def batch_bounds(self): + """ + Protects property from changing. + """ + return self._batch_bounds + + @property + def idx(self): + """ + Protects property from uncontrolled changing. + Changes to _idx require changes to _batch_bounds. + """ + return self._idx + + @idx.setter + def idx(self, x): + self._batch_bounds = _get_batch_start_ends(idx=x, batch_size=self.retrieval_batch_size) + self._idx = np.sort(x) # Sorted indices improve accession efficiency in some cases. 
+ + @property + def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]: + """ + Yields index objects for one epoch of all data. + + These index objects are used by generators that have access to the data objects to build data batches. + Randomization is performed anew with every call to this property. + + :returns: Tuple of: + - Ordering of observations in epoch. + - Batch start and end indices for batch based on ordering defined in first output. + """ + raise NotImplementedError() + + +class BatchDesignBasic(BatchDesignBase): + + @property + def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]: + idx_proc = self.idx.copy() + if self.random_access: + np.random.shuffle(idx_proc) + batch_bounds = self.batch_bounds.copy() + if self.randomized_batch_access: + batch_bounds = _randomize_batch_start_ends(batch_starts_ends=batch_bounds) + return idx_proc, batch_bounds + + +class BatchDesignBalanced(BatchDesignBase): + + def __init__(self, grouping, group_weights: dict, randomized_batch_access: bool, random_access: bool, + **kwargs): + """ + :param grouping: Group label for each entry in idx. + :param group_weights: Group weight for each unique group in grouping. Does not have to normalise to a probability + distribution but is normalised in this function. The outcome vector is always of length idx. + """ + super(BatchDesignBalanced, self).__init__(randomized_batch_access=randomized_batch_access, + random_access=random_access, **kwargs) + if randomized_batch_access: + print("WARNING: randomized_batch_access==True is not a meaningful setting for BatchDesignBalanced.") + if not random_access: + print("WARNING: random_access==False is dangerous if you do not work with a large shuffle buffer " + "downstream of the sfaira generator.") + # Create integer group assignment array. + groups = np.sort(list(group_weights.keys())) + grouping_int = np.zeros((grouping.shape[0],), dtype="int32") - 1 + for i, x in enumerate(groups): + grouping_int[np.where(grouping == x)[0]] = i + assert np.all(grouping_int >= 0) + # Create sampling weights: Sampling weights are a probability distribution over groups. + weight_group = np.array([group_weights[x] for x in groups]) + p_obs = np.asarray(weight_group[grouping_int], dtype="float64") + p_obs = p_obs / np.sum(p_obs) + if np.any(p_obs == 0.): + raise ValueError(f"Down-sampling resulted in zero-probability weights on cells. " + f"Group weights: {weight_group}") + self.p_obs = p_obs + + @property + def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]: + # Re-sample index vector. + idx_proc = np.random.choice(a=self.idx, replace=True, size=len(self.idx), p=self.p_obs) + if not self.random_access: # Note: randomization is result from sampling above, need to revert if not desired. 
+ idx_proc = np.sort(idx_proc) + batch_bounds = self.batch_bounds.copy() + if self.randomized_batch_access: + batch_bounds = _randomize_batch_start_ends(batch_starts_ends=batch_bounds) + return idx_proc, batch_bounds + + +BATCH_SCHEDULE = { + "base": BatchDesignBasic, + "balanced": BatchDesignBalanced, +} diff --git a/sfaira/data/store/generators.py b/sfaira/data/store/generators.py new file mode 100644 index 000000000..a72ee9f56 --- /dev/null +++ b/sfaira/data/store/generators.py @@ -0,0 +1,376 @@ +import anndata +import dask.array +import numpy as np +import pandas as pd +import scipy.sparse +from typing import Dict, List, Union + +from sfaira.data.store.batch_schedule import BATCH_SCHEDULE + + +def split_batch(x, obs): + """ + Splits retrieval batch into consumption batches of length 1. + + Often, end-user consumption batches would be observation-wise, ie yield a first dimension of length 1. + """ + for i in range(x.shape[0]): + yield x[i, :], obs.iloc[[i], :] + + +class GeneratorBase: + """ + A generator is a shallow class that is instantiated on a pointer to a data set in a Store instance. + + This class exposes an iterator generator through `.iterator`. + The iterator can often directly be used without a class like this one around it. + However, that often implies that the namespace with the pointer to the data set is destroyed after iterator + function declaration, which means that the pointer needs to be redefined for every full pass over the iterator. + The class around this property maintains the namespace that includes the pointer and its instance can be used to + avoid redefining the pointer every time the generator runs out and is re-called. + For this run time advantage to not be compromised by memory load and class initialisation run time cost induced by + actually copying data objects, it is important that the data object stored in this class is indeed a pointer. + This is the case for: + + - lazily loaded dask arrays + - anndata.Anndata view + + which have their own classes below. + """ + + @property + def iterator(self) -> iter: + raise NotImplementedError() + + @property + def obs_idx(self): + raise NotImplementedError() + + @property + def n_batches(self) -> int: + raise NotImplementedError() + + def adaptor(self, generator_type: str, **kwargs): + """ + The adaptor turns a python base generator into a different iteratable object, defined by generator_type. + + :param generator_type: Type of output iteratable. + - python base generator (no change to `.generator`) + - tensorflow dataset + - pytorch dataset + :returns: Modified iteratable (see generator_type). + """ + if generator_type == "python": + g = self.iterator + elif generator_type == "tensorflow": + import tensorflow as tf + g = tf.data.Dataset.from_generator(generator=self.iterator, **kwargs) + else: + raise ValueError(f"{generator_type} not recognized") + return g + + +class GeneratorSingle(GeneratorBase): + + batch_size: int + _obs_idx: Union[np.ndarray, None] + obs_keys: List[str] + var_idx: np.ndarray + + def __init__(self, batch_schedule, batch_size, map_fn, obs_idx, obs_keys, var_idx, **kwargs): + """ + + :param batch_schedule: str or class. + - "basic" + - "balanced" + - class: batch_schedule needs to be a class (not instance), subclassing BatchDesignBase. + :param batch_size: Emission batch size. Must be 1. + :param map_fn: Map function to apply to output tuple of raw generator. 
Each draw i from the generator is then: + `yield map_fn(x[i, var_idx], obs[i, obs_keys])` + :param obs_idx: np.ndarray: The cells to emit. + :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available + in self.adata_by_key. + :param var_idx: The features to emit. + """ + self._obs_idx = None + if not batch_size == 1: + raise ValueError(f"Only batch size==1 is supported, found {batch_size}.") + self.batch_schedule = batch_schedule + self.batch_size = batch_size + self.map_fn = map_fn + if isinstance(batch_schedule, str): + batch_schedule = BATCH_SCHEDULE[batch_schedule] + self.schedule = batch_schedule(**kwargs) + self.obs_idx = obs_idx + self.obs_keys = obs_keys + self.var_idx = var_idx + + def _validate_idx(self, idx: Union[np.ndarray, list]) -> np.ndarray: + """ + Validate global index vector. + """ + if len(idx) > 0: + assert np.max(idx) < self.n_obs, f"maximum of supplied index vector {np.max(idx)} exceeds number of " \ + f"modelled observations {self.n_obs}" + assert len(idx) == len(np.unique(idx)), f"repeated indices in idx: {len(idx) - len(np.unique(idx))}" + if isinstance(idx, np.ndarray): + assert len(idx.shape) == 1, idx.shape + assert idx.dtype == np.int + else: + assert isinstance(idx, list) + assert isinstance(idx[0], int) or isinstance(idx[0], np.int) + idx = np.asarray(idx) + return idx + + @property + def obs_idx(self): + return self._obs_idx + + @obs_idx.setter + def obs_idx(self, x): + """Allows emission of different iterator on same generator instance (using same dask array).""" + if x is None: + x = np.arange(0, self.n_obs) + else: + x = self._validate_idx(x) + x = np.sort(x) + # Only reset if they are actually different: + if self._obs_idx is not None and len(x) != len(self._obs_idx) or np.any(x != self._obs_idx): + self._obs_idx = x + self.schedule.idx = x + + @property + def n_obs(self) -> int: + raise NotImplementedError() + + @property + def n_batches(self) -> int: + return len(self.schedule.batch_bounds) + + +class GeneratorAnndata(GeneratorSingle): + + adata_dict: Dict[str, anndata._core.views.ArrayView] + return_dense: bool + single_object: bool + + def __init__(self, adata_dict, idx_dict_global, return_dense, **kwargs): + self.return_dense = return_dense + self.single_object = len(adata_dict.keys()) == 1 + self.idx_dict_global = idx_dict_global + self.adata_dict = adata_dict + super(GeneratorAnndata, self).__init__(**kwargs) + + @property + def n_obs(self) -> int: + return int(np.sum([v.n_obs for v in self.adata_dict.values()])) + + @property + def iterator(self) -> iter: + # Speed up access to single object by skipping index overlap operations: + + def g(): + obs_idx, batch_bounds = self.schedule.design + for s, e in batch_bounds: + idx_i = obs_idx[s:e] + # Match adata objects that overlap to batch: + if self.single_object: + idx_i_dict = dict([(k, np.sort(idx_i)) for k in self.adata_dict.keys()]) + else: + idx_i_set = set(idx_i) + # Return data set-wise index if global index is in target set. + idx_i_dict = dict([ + (k, np.sort([x2 for x1, x2 in zip(v1, v2) if x1 in idx_i_set])) + for k, (v1, v2) in self.idx_dict_global.items() + ]) + # Only retain non-empty. + idx_i_dict = dict([(k, v) for k, v in idx_i_dict.items() if len(v) > 0]) + if self.batch_size == 1: + # Emit each data set separately and avoid concatenation into larger chunks for emission. + for k, v in idx_i_dict.items(): + # I) Prepare data matrix. 
+ x = self.adata_dict[k].X[v, :] + # Move from ArrayView to numpy if backed and dense: + if isinstance(x, anndata._core.views.ArrayView): + x = x.toarray() + if isinstance(x, anndata._core.views.SparseCSRView) or \ + isinstance(x, anndata._core.views.SparseCSCView): + x = x.toarray() + # Do dense conversion now so that col-wise indexing is not slow, often, dense conversion + # would be done later anyway. + if self.return_dense: + x = np.asarray(x.todense()) if isinstance(x, scipy.sparse.spmatrix) else x + if self.var_idx is not None: + x = x[:, self.var_idx] + # Prepare .obs. + obs = self.adata_dict[k].obs[self.obs_keys].iloc[v, :] + for x_i, obs_i in split_batch(x=x, obs=obs): + if self.map_fn is None: + yield x_i, obs_i + else: + output = self.map_fn(x_i, obs_i) + if output is not None: + yield output + else: + # Concatenates slices first before returning. Note that this is likely slower than emitting by + # observation in most scenarios. + # I) Prepare data matrix. + x = [ + self.adata_dict[k].X[v, :] + for k, v in idx_i_dict.items() + ] + # Move from ArrayView to numpy if backed and dense: + x = [ + xx.toarray() + if (isinstance(xx, anndata._core.views.ArrayView) or + isinstance(xx, anndata._core.views.SparseCSRView) or + isinstance(xx, anndata._core.views.SparseCSCView)) + else xx + for xx in x + ] + # Do dense conversion now so that col-wise indexing is not slow, often, dense conversion + # would be done later anyway. + if self.return_dense: + x = [np.asarray(xx.todense()) if isinstance(xx, scipy.sparse.spmatrix) else xx for xx in x] + is_dense = True + else: + is_dense = isinstance(x[0], np.ndarray) + # Concatenate blocks in observation dimension: + if len(x) > 1: + if is_dense: + x = np.concatenate(x, axis=0) + else: + x = scipy.sparse.vstack(x) + else: + x = x[0] + if self.var_idx is not None: + x = x[:, self.var_idx] + # Prepare .obs. + obs = pd.concat([ + self.adata_dict[k].obs[self.obs_keys].iloc[v, :] + for k, v in idx_i_dict.items() + ], axis=0, join="inner", ignore_index=True, copy=False) + if self.map_fn is None: + yield x, obs + else: + output = self.map_fn(x, obs) + if output is not None: + yield output + + return g + + +class GeneratorDask(GeneratorSingle): + + x: dask.array + obs: pd.DataFrame + + def __init__(self, x, obs, **kwargs): + self.x = x + super(GeneratorDask, self).__init__(**kwargs) + self.obs = obs[self.obs_keys] + # Redefine index so that .loc indexing can be used instead of .iloc indexing: + self.obs.index = np.arange(0, obs.shape[0]) + + @property + def n_obs(self) -> int: + return self.x.shape[0] + + @property + def iterator(self) -> iter: + # Can all data sets corresponding to one organism as a single array because they share the second dimension + # and dask keeps expression data and obs out of memory. + + def g(): + obs_idx, batch_bounds = self.schedule.design + x_temp = self.x[obs_idx, :] + obs_temp = self.obs.loc[self.obs.index[obs_idx], :] # TODO better than iloc? + for s, e in batch_bounds: + x_i = x_temp[s:e, :] + if self.var_idx is not None: + x_i = x_i[:, self.var_idx] + # Exploit fact that index of obs is just increasing list of integers, so we can use the .loc[] + # indexing instead of .iloc[]: + obs_i = obs_temp.loc[obs_temp.index[s:e], :] + # TODO place map_fn outside of for loop so that vectorisation in preprocessing can be used. 
+ if self.batch_size == 1: + for x_ii, obs_ii in split_batch(x=x_i, obs=obs_i): + if self.map_fn is None: + yield x_ii, obs_ii + else: + output = self.map_fn(x_ii, obs_ii) + if output is not None: + yield output + else: + if self.map_fn is None: + yield x_i, obs_i + else: + output = self.map_fn(x_i, obs_i) + if output is not None: + yield output + + return g + + +class GeneratorMulti(GeneratorBase): + + generators: Dict[str, GeneratorSingle] + intercalated: bool + + def __init__(self, generators: Dict[str, GeneratorSingle], intercalated: bool = False): + self.generators = generators + self.intercalated = intercalated + self._ratios = None + + @property + def ratios(self): + """ + Define relative drawing frequencies from iterators for intercalation. + """ + if self._ratios is None: + gen_lens = np.array([v.n_batches for v in self.generators.values()]) + self._ratios = np.asarray(np.round(np.max(gen_lens) / np.asarray(gen_lens), 0), dtype="int64") + return self._ratios + + @property + def obs_idx(self): + return dict([(k, v.obs_idx) for k, v in self.generators.items()]) + + @obs_idx.setter + def obs_idx(self, x): + """Allows emission of different iterator on same generator instance (using same dask array).""" + if x is None: + x = dict([(k, None) for k in self.generators.keys()]) + for k in self.generators.keys(): + assert k in x.keys(), (x.keys(), self.generators.keys()) + self.generators[k].obs_idx = x[k] + self._ratios = None # Reset ratios. + + @property + def iterator(self) -> iter: + + if self.intercalated: + def g(): + # Document which generators are still yielding batches: + yielding = np.ones((self.ratios.shape[0],)) == 1. + iterators = [v.iterator() for v in self.generators.values()] + while np.any(yielding): + # Loop over one iterator length adjusted cycle of emissions. + for i, (g, n) in enumerate(zip(iterators, self.ratios)): + for _ in range(n): + try: + x = next(g) + yield x + except StopIteration: + yielding[i] = False + else: + def g(): + for gi in self.generators.values(): + for x in gi.iterator(): + yield x + + return g + + @property + def n_batches(self) -> int: + return np.sum([v.n_batches for v in self.generators.values()]) diff --git a/sfaira/data/store/io_dao.py b/sfaira/data/store/io_dao.py index 55f7fcab6..7da0bc8c6 100644 --- a/sfaira/data/store/io_dao.py +++ b/sfaira/data/store/io_dao.py @@ -105,6 +105,7 @@ def read_dao(store: Union[str, Path], use_dask: bool = True, columns: Union[None - AnnData with .X as dask array. - obs table separately as dataframe """ + assert not (obs_separate and x_separate), "either request obs_separate or x_separate, or neither, but not both" if use_dask: x = dask.array.from_zarr(url=path_x(store), component="X") else: @@ -131,5 +132,7 @@ def read_dao(store: Union[str, Path], use_dask: bool = True, columns: Union[None adata.obs = obs if obs_separate: return adata, obs - if x_separate: + elif x_separate: return adata, x + else: + return adata diff --git a/sfaira/data/store/load_store.py b/sfaira/data/store/load_store.py new file mode 100644 index 000000000..26231e257 --- /dev/null +++ b/sfaira/data/store/load_store.py @@ -0,0 +1,34 @@ +import os +from typing import List, Union + +from sfaira.data.store.multi_store import DistributedStoresDao, DistributedStoresH5ad, \ + DistributedStoreMultipleFeatureSpaceBase + + +def load_store(cache_path: Union[str, os.PathLike], store_format: str = "dao", + columns: Union[None, List[str]] = None) -> DistributedStoreMultipleFeatureSpaceBase: + """ + Instantiates a distributed store class. 
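A brief usage sketch of this factory; the cache path is a placeholder and the format strings map to the backend classes documented below.

```python
# Placeholder path; store_format selects the backend ("dao", "h5ad", "h5ad_backed").
import sfaira

store = sfaira.data.load_store(cache_path="/path/to/store", store_format="dao")
backed = sfaira.data.load_store(cache_path="/path/to/store", store_format="h5ad_backed")
```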
+ + Note that any store is instantiated as a DistributedStoreMultipleFeatureSpaceBase. + This instances can be subsetted to the desired single feature space. + + :param cache_path: Store directory. + :param store_format: Format of store {"h5ad", "dao"}. + + - "h5ad": Returns instance of DistributedStoreH5ad and keeps data in memory. See also "h5ad_backed". + - "dao": Returns instance of DistributedStoreDoa (distributed access optimized). + - "h5ad_backed": Returns instance of DistributedStoreH5ad and keeps data as backed (out of memory). See also + "h5ad". + :param columns: Which columns to read into the obs copy in the output, see pandas.read_parquet(). + Only relevant if store_format is "dao". + :return: Instances of a distributed store class. + """ + if store_format == "h5ad": + return DistributedStoresH5ad(cache_path=cache_path, in_memory=True) + elif store_format == "dao": + return DistributedStoresDao(cache_path=cache_path, columns=columns) + elif store_format == "h5ad_backed": + return DistributedStoresH5ad(cache_path=cache_path, in_memory=False) + else: + raise ValueError(f"Did not recognize store_format {store_format}.") diff --git a/sfaira/data/store/multi_store.py b/sfaira/data/store/multi_store.py index dc4e51521..a041c7b86 100644 --- a/sfaira/data/store/multi_store.py +++ b/sfaira/data/store/multi_store.py @@ -1,18 +1,21 @@ -import abc import anndata +import dask.dataframe import numpy as np import os +import pandas as pd import pickle from typing import Dict, List, Tuple, Union from sfaira.consts import AdataIdsSfaira +from sfaira.data.store.base import DistributedStoreBase from sfaira.data.store.single_store import DistributedStoreSingleFeatureSpace, \ - DistributedStoreDao, DistributedStoreH5ad + DistributedStoreDao, DistributedStoreAnndata +from sfaira.data.store.generators import GeneratorMulti from sfaira.data.store.io_dao import read_dao from sfaira.versions.genomes.genomes import GenomeContainer -class DistributedStoreMultipleFeatureSpaceBase(abc.ABC): +class DistributedStoreMultipleFeatureSpaceBase(DistributedStoreBase): """ Umbrella class for a dictionary over multiple instances DistributedStoreSingleFeatureSpace. @@ -38,11 +41,11 @@ def stores(self, x: Dict[str, DistributedStoreSingleFeatureSpace]): raise NotImplementedError("cannot set this attribute, it s defined in constructor") @property - def genome_containers(self) -> Dict[str, Union[GenomeContainer, None]]: + def genome_container(self) -> Dict[str, Union[GenomeContainer, None]]: return dict([(k, v.genome_container) for k, v in self._stores.items()]) - @genome_containers.setter - def genome_containers(self, x: Union[GenomeContainer, Dict[str, GenomeContainer]]): + @genome_container.setter + def genome_container(self, x: Union[GenomeContainer, Dict[str, GenomeContainer]]): if isinstance(x, GenomeContainer): # Transform into dictionary first. organisms = [k for k, v in self.stores.items()] @@ -79,6 +82,14 @@ def data_by_key(self): """ return dict([(kk, vv) for k, v in self.stores.items() for kk, vv in v.data_by_key.items()]) + @property + def obs_by_key(self) -> Dict[str, Union[pd.DataFrame, dask.dataframe.DataFrame]]: + """ + Dictionary of all anndata instances for each selected data set in store, sub-setted by selected cells, for each + stores. 
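The umbrella class keeps one single-feature-space store per organism; a GenomeContainer can be attached through the renamed `genome_container` property (formerly `genome_containers`) and is broadcast to every contained store. A sketch with a placeholder path, mirroring the usage in the test script further below:

```python
# Assembly and gene symbols follow the test script later in this patch; the cache
# path is a placeholder.
import sfaira

gc = sfaira.versions.genomes.genomes.GenomeContainer(assembly="Homo_sapiens.GRCh38.102")
gc.subset(symbols=["VTA1", "MLXIPL", "BAZ1B"])
store = sfaira.data.load_store(cache_path="/path/to/store", store_format="dao")
store.genome_container = gc  # broadcast to every contained single-feature-space store
```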
+ """ + return dict([(k, v.obs) for k, v in self.adata_by_key.items()]) + @property def var_names(self) -> Dict[str, List[str]]: """ @@ -94,7 +105,14 @@ def n_vars(self) -> Dict[str, int]: return dict([(k, v.n_vars) for k, v in self.stores.items()]) @property - def n_obs(self) -> Dict[str, int]: + def n_obs(self) -> int: + """ + Dictionary of number of observations across stores. + """ + return np.asarray(np.sum([v.n_obs for v in self.stores.values()]), dtype="int32") + + @property + def n_obs_dict(self) -> Dict[str, int]: """ Dictionary of number of observations by store. """ @@ -107,6 +125,13 @@ def obs(self): """ return dict([(k, v.obs) for k, v in self.stores.items()]) + @property + def var(self): + """ + Dictionaries of .var tables by store, only including non-empty stores. + """ + return dict([(k, v.var) for k, v in self.stores.items()]) + @property def X(self): """ @@ -134,7 +159,7 @@ def subset(self, attr_key, values: Union[str, List[str], None] = None, - "assay_sc" points to self.assay_sc_obs_key - "assay_type_differentiation" points to self.assay_type_differentiation_obs_key - "cell_line" points to self.cell_line - - "cellontology_class" points to self.cellontology_class_obs_key + - "cell_type" points to self.cell_type_obs_key - "developmental_stage" points to self.developmental_stage_obs_key - "ethnicity" points to self.ethnicity_obs_key - "organ" points to self.organ_obs_key @@ -175,15 +200,16 @@ def load_config(self, fn: Union[str, os.PathLike]): with open(fn, 'rb') as f: indices = pickle.load(f) # Distribute indices to corresponding stores by matched keys. - keys_not_found = list(indices.keys()) + keys_found = [] for k, v in self.stores.items(): indices_k = {} - for i, (kk, vv) in enumerate(indices.items()): + for kk, vv in indices.items(): if kk in v.adata_by_key.keys(): indices_k[kk] = vv - del keys_not_found[i] + keys_found.append(kk) self.stores[k].indices = indices_k # Make sure all declared data were assigned to stores: + keys_not_found = list(set(list(indices.keys())).difference(set(keys_found))) if len(keys_not_found) > 0: raise ValueError(f"did not find object(s) with name(s) in store: {keys_not_found}") @@ -192,7 +218,7 @@ def generator( idx: Union[Dict[str, Union[np.ndarray, None]], None] = None, intercalated: bool = True, **kwargs - ) -> Tuple[iter, int]: + ) -> GeneratorMulti: """ Emission of batches from unbiased generators of all stores. @@ -201,41 +227,51 @@ def generator( :param idx: :param intercalated: Whether to do sequential or intercalated emission. :param kwargs: See parameters of DistributedStore*.generator(). + :return: Generator function which yields batch_size at every invocation. + The generator returns a tuple of (.X, .obs). """ if idx is None: idx = dict([(k, None) for k in self.stores.keys()]) for k in self.stores.keys(): assert k in idx.keys(), (idx.keys(), self.stores.keys()) - generators = [ - v.generator(idx=idx[k], **kwargs) - for k, v in self.stores.items() - ] - generator_fns = [x[0]() for x in generators] - generator_len = [x[1] for x in generators] - - if intercalated: - # Define relative drawing frequencies from iterators for intercalation. - ratio = np.asarray(np.round(np.max(generator_len) / np.asarray(generator_len), 0), dtype="int64") - - def generator(): - # Document which generators are still yielding batches: - yielding = np.ones((ratio.shape[0],)) == 1. - while np.any(yielding): - # Loop over one iterator length adjusted cycle of emissions. 
- for i, (g, n) in enumerate(zip(generator_fns, ratio)): - for _ in range(n): - try: - x = next(g) - yield x - except StopIteration: - yielding[i] = False - else: - def generator(): - for g in generator_fns: - for x in g(): - yield x - - return generator, int(np.sum(generator_len)) + generators = dict([(k, v.generator(idx=idx[k], **kwargs)) for k, v in self.stores.items()]) + return GeneratorMulti(generators=generators, intercalated=intercalated) + + +class DistributedStoresAnndata(DistributedStoreMultipleFeatureSpaceBase): + + def __init__(self, adatas: Union[anndata.AnnData, List[anndata.AnnData], Tuple[anndata.AnnData]]): + # Collect all data loaders from files in directory: + self._adata_ids_sfaira = AdataIdsSfaira() + adata_by_key = {} + indices = {} + if isinstance(adatas, anndata.AnnData): + adatas = [adatas] + for adata in adatas: + organism = adata.uns[self._adata_ids_sfaira.organism] + if isinstance(organism, list): + if len(organism) == 1: + organism = organism[0] + assert isinstance(organism, str), organism + else: + raise ValueError(f"tried to register mixed organism data set ({organism})") + adata_id = adata.uns[self._adata_ids_sfaira.id] + # Make up a new merged ID for data set indexing if there is a list of IDs in .uns. + if isinstance(adata_id, list): + adata_id = "_".join(adata_id) + if organism not in adata_by_key.keys(): + adata_by_key[organism] = {} + indices[organism] = {} + try: + adata_by_key[organism][adata_id] = adata + indices[organism][adata_id] = np.arange(0, adata.n_obs) + except TypeError as e: + raise TypeError(f"{e} for {organism} or {adata.uns[self._adata_ids_sfaira.id]}") + stores = dict([ + (k, DistributedStoreAnndata(adata_by_key=adata_by_key[k], indices=indices[k], in_memory=True)) + for k in adata_by_key.keys() + ]) + super(DistributedStoresAnndata, self).__init__(stores=stores) class DistributedStoresDao(DistributedStoreMultipleFeatureSpaceBase): @@ -267,9 +303,9 @@ def __init__(self, cache_path: Union[str, os.PathLike], columns: Union[None, Lis adata_by_key[organism] = {} x_by_key[organism] = {} indices[organism] = {} - adata_by_key[organism][adata.uns["id"]] = adata - x_by_key[organism][adata.uns["id"]] = x - indices[organism][adata.uns["id"]] = np.arange(0, adata.n_obs) + adata_by_key[organism][adata.uns[self._adata_ids_sfaira.id]] = adata + x_by_key[organism][adata.uns[self._adata_ids_sfaira.id]] = x + indices[organism][adata.uns[self._adata_ids_sfaira.id]] = np.arange(0, adata.n_obs) self._x_by_key = x_by_key stores = dict([ (k, DistributedStoreDao(adata_by_key=adata_by_key[k], x_by_key=x_by_key[k], indices=indices[k], @@ -305,34 +341,10 @@ def __init__(self, cache_path: Union[str, os.PathLike], in_memory: bool = False) if organism not in adata_by_key.keys(): adata_by_key[organism] = {} indices[organism] = {} - adata_by_key[organism][adata.uns["id"]] = adata - indices[organism][adata.uns["id"]] = np.arange(0, adata.n_obs) + adata_by_key[organism][adata.uns[self._adata_ids_sfaira.id]] = adata + indices[organism][adata.uns[self._adata_ids_sfaira.id]] = np.arange(0, adata.n_obs) stores = dict([ - (k, DistributedStoreH5ad(adata_by_key=adata_by_key[k], indices=indices[k], in_memory=in_memory)) + (k, DistributedStoreAnndata(adata_by_key=adata_by_key[k], indices=indices[k], in_memory=in_memory)) for k in adata_by_key.keys() ]) super(DistributedStoresH5ad, self).__init__(stores=stores) - - -def load_store(cache_path: Union[str, os.PathLike], store_format: str = "dao", - columns: Union[None, List[str]] = None) -> Union[DistributedStoresH5ad, 
DistributedStoresDao]: - """ - Instantiates a distributed store class. - - :param cache_path: Store directory. - :param store_format: Format of store {"h5ad", "dao"}. - - - "h5ad": Returns instance of DistributedStoreH5ad. - - "dao": Returns instance of DistributedStoreDoa (distributed access optimized). - :param columns: Which columns to read into the obs copy in the output, see pandas.read_parquet(). - Only relevant if store_format is "dao". - :return: Instances of a distributed store class. - """ - if store_format == "anndata": - return DistributedStoresH5ad(cache_path=cache_path, in_memory=True) - elif store_format == "dao": - return DistributedStoresDao(cache_path=cache_path, columns=columns) - elif store_format == "h5ad": - return DistributedStoresH5ad(cache_path=cache_path, in_memory=False) - else: - raise ValueError(f"Did not recognize store_format {store_format}.") diff --git a/sfaira/data/store/single_store.py b/sfaira/data/store/single_store.py index c9e8e66fe..4f4cd19ca 100644 --- a/sfaira/data/store/single_store.py +++ b/sfaira/data/store/single_store.py @@ -2,18 +2,17 @@ import anndata import dask.array import dask.dataframe -import h5py import numpy as np import os import pandas as pd import pickle import scipy.sparse -import sys -import time from typing import Dict, List, Tuple, Union from sfaira.consts import AdataIdsSfaira, OCS from sfaira.data.dataloaders.base.utils import is_child, UNS_STRING_META_IN_OBS +from sfaira.data.store.base import DistributedStoreBase +from sfaira.data.store.generators import GeneratorAnndata, GeneratorDask, GeneratorSingle from sfaira.versions.genomes.genomes import GenomeContainer """ @@ -35,16 +34,13 @@ """ -def _process_batch_size(x: int, idx: np.ndarray) -> int: - if x > len(idx): - batch_size_new = len(idx) - print(f"WARNING: reducing retrieval batch size according to data availability in store " - f"from {x} to {batch_size_new}") - x = batch_size_new - return x +def _process_batch_size(batch_size: int, retrival_batch_size: int) -> Tuple[int, int]: + if batch_size != 1: + raise ValueError("batch size is only supported as 1") + return batch_size, retrival_batch_size -class DistributedStoreSingleFeatureSpace: +class DistributedStoreSingleFeatureSpace(DistributedStoreBase): """ Data set group class tailored to data access requirements common in high-performance computing (HPC). @@ -91,22 +87,6 @@ def idx(self) -> np.ndarray: idx_global = np.arange(0, np.sum([len(v) for v in self.indices.values()])) return idx_global - def _validate_idx(self, idx: Union[np.ndarray, list]) -> np.ndarray: - """ - Validate global index vector. - """ - assert np.max(idx) < self.n_obs, f"maximum of supplied index vector {np.max(idx)} exceeds number of modelled " \ - f"observations {self.n_obs}" - assert len(idx) == len(np.unique(idx)), f"there were {len(idx) - len(np.unique(idx))} repeated indices in idx" - if isinstance(idx, np.ndarray): - assert len(idx.shape) == 1, idx.shape - assert idx.dtype == np.int - else: - assert isinstance(idx, list) - assert isinstance(idx[0], int) or isinstance(idx[0], np.int) - idx = np.asarray(idx) - return idx - @property def organisms_by_key(self) -> Dict[str, str]: """ @@ -160,13 +140,6 @@ def data_by_key(self): """ return dict([(k, v.X) for k, v in self.adata_by_key.items()]) - @property - def adata_memory_footprint(self) -> Dict[str, float]: - """ - Memory foot-print of data set k in MB. 
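For context, the cell-wise subsetting and index persistence retained on this class can be exercised as follows. This is a hedged sketch: the organism key, metadata values and paths are illustrative, not taken from the patch.

```python
# Illustrative keys and values; "human" as organism key is an assumption.
# subset() narrows .indices in place; "cell_type" replaces the former "cellontology_class".
import sfaira

store = sfaira.data.load_store(cache_path="/path/to/store", store_format="dao")
store_human = store.stores["human"]  # one DistributedStoreSingleFeatureSpace per organism
store_human.subset(attr_key="organ", values=["lung"])
store_human.subset(attr_key="cell_type", excluded_values=["unknown"])
store_human.write_config(fn="/path/to/lung_config")  # persist the selected indices
```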
- """ - return dict([(k, sys.getsizeof(v) / np.power(1024, 2)) for k, v in self.adata_by_key.items()]) - @property def indices(self) -> Dict[str, np.ndarray]: """ @@ -246,7 +219,7 @@ def get_subset_idx(self, attr_key, values: Union[str, List[str], None], - "assay_sc" points to self.assay_sc_obs_key - "assay_type_differentiation" points to self.assay_type_differentiation_obs_key - "cell_line" points to self.cell_line - - "cellontology_class" points to self.cellontology_class_obs_key + - "cell_type" points to self.cell_type_obs_key - "developmental_stage" points to self.developmental_stage_obs_key - "ethnicity" points to self.ethnicity_obs_key - "organ" points to self.organ_obs_key @@ -266,9 +239,11 @@ def get_subset_idx(self, attr_key, values: Union[str, List[str], None], def get_idx(adata, obs, k, v, xv, dataset): # Use cell-wise annotation if data set-wide maps are ambiguous: # This can happen if the different cell-wise annotations are summarised as a union in .uns. - if getattr(self._adata_ids_sfaira, k) in adata.uns.keys() and \ - adata.uns[getattr(self._adata_ids_sfaira, k)] != UNS_STRING_META_IN_OBS and \ - getattr(self._adata_ids_sfaira, k) not in obs.columns: + read_from_uns = (getattr(self._adata_ids_sfaira, k) in adata.uns.keys() and + adata.uns[getattr(self._adata_ids_sfaira, k)] != UNS_STRING_META_IN_OBS and + getattr(self._adata_ids_sfaira, k) not in obs.columns) + read_from_obs = not read_from_uns and getattr(self._adata_ids_sfaira, k) in obs.columns + if read_from_uns: values_found = adata.uns[getattr(self._adata_ids_sfaira, k)] if isinstance(values_found, np.ndarray): values_found = values_found.tolist() @@ -279,14 +254,11 @@ def get_idx(adata, obs, k, v, xv, dataset): else: # Replicate unique property along cell dimension. values_found = [values_found[0] for _ in range(adata.n_obs)] + elif read_from_obs: + values_found = obs[getattr(self._adata_ids_sfaira, k)].values else: - values_found = None - if values_found is None: - if getattr(self._adata_ids_sfaira, k) in obs.columns: - values_found = obs[getattr(self._adata_ids_sfaira, k)].values - else: - values_found = [] - print(f"WARNING: did not find attribute {k} in data set {dataset}") + values_found = [] + print(f"WARNING: did not find attribute {k} in data set {dataset}") values_found_unique = np.unique(values_found) try: ontology = getattr(self.ontology_container, k) @@ -340,7 +312,7 @@ def subset(self, attr_key, values: Union[str, List[str], None] = None, - "assay_sc" points to self.assay_sc_obs_key - "assay_type_differentiation" points to self.assay_type_differentiation_obs_key - "cell_line" points to self.cell_line - - "cellontology_class" points to self.cellontology_class_obs_key + - "cell_type" points to self.cell_type_obs_key - "developmental_stage" points to self.developmental_stage_obs_key - "ethnicity" points to self.ethnicity_obs_key - "organ" points to self.organ_obs_key @@ -353,7 +325,7 @@ def subset(self, attr_key, values: Union[str, List[str], None] = None, """ self.indices = self.get_subset_idx(attr_key=attr_key, values=values, excluded_values=excluded_values) if self.n_obs == 0 and verbose > 0: - print("WARNING: store is now empty.") + print(f"WARNING: store is now empty after subsetting {attr_key} for {values}, excluding {excluded_values}.") def write_config(self, fn: Union[str, os.PathLike]): """ @@ -416,54 +388,103 @@ def n_obs(self) -> int: def shape(self) -> Tuple[int, int]: return self.n_obs, self.n_vars - @abc.abstractmethod - def _generator( + def _index_curation_helper( self, - idx_gen: iter, - 
var_idx: Union[np.ndarray, None], - obs_keys: List[str], - ) -> iter: - pass - - def _generator_helper( - self, - idx: Union[np.ndarray, None], batch_size: int, - ) -> Tuple[Union[np.ndarray, None], Union[np.ndarray, None], int]: + retrival_batch_size: int, + ) -> Tuple[Union[np.ndarray, None], int, int]: + """ + Process indices and batch size input for generator production. + + Feature indices are formatted based on previously loaded genome container. + + :param batch_size: Number of observations read from disk in each batched access (generator invocation). + :return: Tuple: + - var_idx: Processed feature index vector for generator to access. + - batch_size: Processed batch size for generator to access. + - retrival_batch_size: Processed retrieval batch size for generator to access. + """ # Make sure that features are ordered in the same way in each object so that generator yields consistent cell # vectors. var_names = self._validate_feature_space_homogeneity() # Use feature space sub-selection based on assembly if provided, will use full feature space otherwise. if self.genome_container is not None: var_names_target = self.genome_container.ensembl - var_idx = np.sort([var_names.index(x) for x in var_names_target]) # Check if index vector is just full ordered list of indices, in this case, sub-setting is unnecessary. - if len(var_idx) == len(var_names) and np.all(var_idx == np.arange(0, len(var_names))): + if len(var_names_target) == len(var_names) and np.all(var_names_target == var_names): var_idx = None + else: + # Check if variable names are continuous stretch in reference list, indexing this is much faster. + # Note: There is about 5 sec to be saved on a call because if len(var_names_target) calls to .index + # on a list of length var_names are avoided. + # One example in this would save about 5 sec would be selection of protein coding genes from a full + # gene space in which protein coding genes grouped together (this is not the case in the standard + # assembly). + idx_first = var_names.index(var_names_target[0]) + idx_last = idx_first + len(var_names_target) + if idx_last <= len(var_names) and np.all(var_names_target == var_names[idx_first:idx_last]): + var_idx = np.arange(idx_first, idx_last) + else: + var_idx = np.sort([var_names.index(x) for x in var_names_target]) else: var_idx = None - if idx is not None: - idx = self._validate_idx(idx) - batch_size = _process_batch_size(x=batch_size, idx=idx) - return idx, var_idx, batch_size + # Select all cells if idx was None: + batch_size, retrival_batch_size = _process_batch_size(batch_size=batch_size, + retrival_batch_size=retrival_batch_size) + return var_idx, batch_size, retrival_batch_size + + @abc.abstractmethod + def _get_generator( + self, + batch_schedule, + obs_idx: np.ndarray, + var_idx: Union[np.ndarray, None], + map_fn, + obs_keys: List[str], + **kwargs + ) -> iter: + """ + Yields an instance of GeneratorSingle which can emit an iterator over the data defined in the arguments here. + + :param obs_idx: The observations to emit. + :param var_idx: The features to emit. + :param map_fn: Map functino to apply to output tuple of raw generator. Each draw i from the generator is then: + `yield map_fn(x[i, var_idx], obs[i, obs_keys])` + :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available + in self.adata_by_key. + :return: GeneratorSingle instance. 
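A standalone restatement of the contiguous-block shortcut used in the index curation above: if the target feature names form one uninterrupted stretch of the store's var_names, a cheap `np.arange` slice replaces per-gene `.index()` lookups.

```python
# Toy feature names for illustration; the logic mirrors _index_curation_helper above.
import numpy as np

var_names = ["g0", "g1", "g2", "g3", "g4", "g5"]
var_names_target = ["g2", "g3", "g4"]
idx_first = var_names.index(var_names_target[0])
idx_last = idx_first + len(var_names_target)
if idx_last <= len(var_names) and np.all(
        np.asarray(var_names_target) == np.asarray(var_names[idx_first:idx_last])):
    var_idx = np.arange(idx_first, idx_last)       # contiguous stretch: cheap slice
else:
    var_idx = np.sort([var_names.index(x) for x in var_names_target])
assert var_idx.tolist() == [2, 3, 4]
```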
+ """ + pass def generator( self, idx: Union[np.ndarray, None] = None, batch_size: int = 1, + retrieval_batch_size: int = 128, + map_fn=None, obs_keys: List[str] = [], return_dense: bool = True, randomized_batch_access: bool = False, random_access: bool = False, + batch_schedule: str = "base", **kwargs - ) -> iter: + ) -> GeneratorSingle: """ - Yields an unbiased generator over observations in the contained data sets. + Yields an instance of a generator class over observations in the contained data sets. - :param idx: Global idx to query from store. These is an array with indicies corresponding to a contiuous index + Multiple such instances can be emitted by a single store class and point to data stored in this store class. + Effectively, these generators are heavily reduced pointers to the data in an instance of self. + A common use case is the instantiation of a training data generator and a validation data generator over a data + subset defined in this class. + + :param idx: Global idx to query from store. These is an array with indices corresponding to a contiuous index along all observations in self.adata_by_key, ordered along a hypothetical concatenation along the keys of - self.adata_by_key. - :param batch_size: Number of observations read from disk in each batched access (generator invocation). + self.adata_by_key. If None, all observations are selected. + :param batch_size: Number of observations to yield in each access (generator invocation). + :param retrieval_batch_size: Number of observations read from disk in each batched access (data-backend generator + invocation). + :param map_fn: Map functino to apply to output tuple of raw generator. Each draw i from the generator is then: + `yield map_fn(x[i, var_idx], obs[i, obs_keys])` :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available in self.adata_by_key. :param return_dense: Whether to force return count data .X as dense batches. This allows more efficient feature @@ -476,111 +497,26 @@ def generator( :param random_access: Whether to fully shuffle observations before batched access takes place. May slow down access compared randomized_batch_access and to no randomization. Do not use randomized_batch_access and random_access. + :param batch_schedule: Re + - "base" + - "balanced": idx_generator_kwarg need to include: + - "balance_obs": .obs column key to balance samples from each data set over. + Note that each data set must contain this column in its .obs table. + - "balance_damping": Damping to apply to class weighting induced by balance_obs. The class-wise + wise sampling probabilities become `max(balance_damping, (1. - frequency))` + - function: This can be a function that satisfies the interface. It will also receive idx_generator_kwarg. + :param kwargs: kwargs for idx_generator chosen. :return: Generator function which yields batch_size at every invocation. The generator returns a tuple of (.X, .obs). """ - idx, var_idx, batch_size = self._generator_helper(idx=idx, batch_size=batch_size) - if randomized_batch_access and random_access: - raise ValueError("Do not use randomized_batch_access and random_access.") - n_obs = len(idx) - remainder = n_obs % batch_size - n_batches = int(n_obs // batch_size + int(remainder > 0)) - - def idx_gen(): - """ - Yields index objects for one epoch of all data. - - These index objects are used by generators that have access to the data objects to build data batches. - - :returns: Tuple of: - - Ordering of observations in epoch. 
- - Batch start and end indices for batch based on ordering defined in first output. - """ - batch_starts_ends = [ - (int(x * batch_size), int(np.minimum((x * batch_size) + batch_size, n_obs))) - for x in np.arange(0, n_batches) - ] - batch_range = np.arange(0, len(batch_starts_ends)) - if randomized_batch_access: - np.random.shuffle(batch_range) - batch_starts_ends = [batch_starts_ends[i] for i in batch_range] - obs_idx = idx.copy() - if random_access: - np.random.shuffle(obs_idx) - yield obs_idx, batch_starts_ends - - return self._generator(idx_gen=idx_gen(), var_idx=var_idx, obs_keys=obs_keys), n_batches - - def generator_balanced( - self, - idx: Union[np.ndarray, None] = None, - balance_obs: Union[str, None] = None, - balance_damping: float = 0., - batch_size: int = 1, - obs_keys: List[str] = [], - **kwargs - ) -> iter: - """ - Yields a data set balanced generator. - - Yields one random batch per dataset. Assumes that data sets are annotated in .obs. - Uses self.dataset_weights if this are given to sample data sets with different frequencies. - Can additionally also balance across one meta data annotation within each data set. - - Assume you have a data set with two classes (A=80, B=20 cells) in a column named "cellontology_class". - The single batch for this data set produced by this generator in each epoch contains N cells. - If balance_obs is False, these N cells are the result of a draw without replacement from all 100 cells in this - dataset in which each cell receives the same weight / success probability of 1.0. - If balance_obs is True, these N cells are the result of a draw without replacement from all 100 cells in this - data set with individual success probabilities such that classes are balanced: 0.2 for A and 0.8 for B. - - :param idx: Global idx to query from store. These is an array with indicies corresponding to a contiuous index - along all observations in self.adata_by_key, ordered along a hypothetical concatenation along the keys of - self.adata_by_key. - :param balance_obs: .obs column key to balance samples from each data set over. - Note that each data set must contain this column in its .obs table. - :param balance_damping: Damping to apply to class weighting induced by balance_obs. The class-wise - wise sampling probabilities become `max(balance_damping, (1. - frequency))` - :param batch_size: Number of observations read from disk in each batched access (generator invocation). - :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available - in self.adata_by_key. - :return: Generator function which yields batch_size at every invocation. - The generator returns a tuple of (.X, .obs). 
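With `generator_balanced` removed, such draws are requested through the single `generator()` entry point, selecting `batch_schedule="balanced"` with the `balance_obs`/`balance_damping` kwargs described in the new docstring. A hedged sketch of the default schedule, continuing the `store_human` instance from the subsetting sketch above; the obs column name is illustrative.

```python
# store_human: DistributedStoreSingleFeatureSpace from the earlier sketch; obs_keys
# values are illustrative. gen.iterator is the raw python generator factory and
# gen.n_batches gives the number of batches per epoch.
gen = store_human.generator(
    batch_size=1,
    retrieval_batch_size=128,
    obs_keys=["cell_type"],
    randomized_batch_access=True,
)
for x_i, obs_i in gen.iterator():
    break  # each draw: (feature vector of one observation, one-row .obs table)
```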
- """ - idx, var_idx, batch_size = self._generator_helper(idx=idx, batch_size=batch_size) - - def idx_gen(): - batch_starts_ends = [] - idx_proc = [] - val_dataset = self.obs[self._adata_ids_sfaira.dataset].values - datasets = np.unique(val_dataset) - if self.dataset_weights is not None: - weights = np.array([self.dataset_weights[x] for x in datasets]) - p = weights / np.sum(weights) - datasets = np.random.choice(a=datasets, replace=True, size=len(datasets), p=p) - if balance_obs is not None: - val_meta = self.obs[balance_obs].values - for x in datasets: - idx_x = np.where(val_dataset == x)[0] - n_obs = len(idx_x) - batch_size_o = int(np.minimum(batch_size, n_obs)) - batch_starts_ends.append(np.array([(0, batch_size_o), ])) - if balance_obs is None: - p = np.ones_like(idx_x) / len(idx_x) - else: - if balance_obs not in self.obs.columns: - raise ValueError(f"did not find column {balance_obs} in {self.organism}") - val_meta_x = val_meta[idx_x] - class_freq = dict([(y, np.mean(val_meta_x == y)) for y in np.unique(val_meta_x)]) - class_freq_x_by_obs = np.array([class_freq[y] for y in val_meta_x]) - damped_freq_coefficient = np.maximum(balance_damping, (1. - class_freq_x_by_obs)) - p = np.ones_like(idx_x) / len(idx_x) * damped_freq_coefficient - idx_x_sample = np.random.choice(a=idx_x, replace=False, size=batch_size_o, p=p) - idx_proc.append(idx_x_sample) - idx_proc = np.asarray(idx_proc) - yield idx_proc, batch_starts_ends - - return self._generator(idx_gen=idx_gen(), var_idx=var_idx, obs_keys=obs_keys) + var_idx, batch_size, retrieval_batch_size = self._index_curation_helper( + batch_size=batch_size, retrival_batch_size=retrieval_batch_size) + batch_schedule_kwargs = {"randomized_batch_access": randomized_batch_access, + "random_access": random_access, + "retrieval_batch_size": retrieval_batch_size} + gen = self._get_generator(batch_schedule=batch_schedule, batch_size=batch_size, map_fn=map_fn, obs_idx=idx, + obs_keys=obs_keys, var_idx=var_idx, **batch_schedule_kwargs, **kwargs) + return gen @property @abc.abstractmethod @@ -592,18 +528,87 @@ def X(self): def obs(self) -> Union[pd.DataFrame]: pass + @property + def var(self) -> Union[pd.DataFrame]: + if self.genome_container is None: + var = pd.DataFrame({}, index=self.var_names) + else: + var = pd.DataFrame({ + "ensg": self.genome_container.ensembl, + "symbol": self.genome_container.symbols, + }, index=self.var_names) + return var + + def adata_slice(self, idx: np.ndarray, as_sparse: bool = True, **kwargs) -> anndata.AnnData: + """ + Assembles a slice of a store as a anndata instance using a generator. -class DistributedStoreH5ad(DistributedStoreSingleFeatureSpace): + Avoids loading entire data into memory first to then index. Uses .X_slice and loads var annotation from + .genome_container. + Note: this slice is a slice based on the subset already selected via previous subsetting on this instance. + + :param idx: Global idx to query from store. These is an array with indices corresponding to a contiuous index + along all observations in self.adata_by_key, ordered along a hypothetical concatenation along the keys of + self.adata_by_key. If None, all observations are selected. + :param as_sparse: Whether to format .X as a sparse matrix. + :param kwargs: kwargs to .generator(). + :return: Slice of data array. + """ + # Note: .obs is already in memory so can be sliced in memory without great disadvantages. 
+ return anndata.AnnData( + X=self.X_slice(idx=idx, as_sparse=as_sparse, **kwargs), + obs=self.obs.iloc[idx, :], + var=self.var + ) + + def X_slice(self, idx: np.ndarray, as_sparse: bool = True, **kwargs) -> Union[np.ndarray, scipy.sparse.csr_matrix]: + """ + Assembles a slice of a store data matrix as a numpy / scipy array using a generator. + + Avoids loading entire data matrix first to then index, ie replaces: + + ``` python + # idx = some indices + x = store.X + x = x[idx,:] + ``` + + Note: this slice is a slice based on the subset already selected via previous subsetting on this instance. + + :param idx: Global idx to query from store. These is an array with indices corresponding to a contiuous index + along all observations in self.adata_by_key, ordered along a hypothetical concatenation along the keys of + self.adata_by_key. If None, all observations are selected. + :param as_sparse: Whether to return a sparse matrix. + :param kwargs: kwargs to .generator(). + :return: Slice of data array. + """ + batch_size = min(len(idx), 128) + g = self.generator(idx=idx, retrieval_batch_size=batch_size, return_dense=True, random_access=False, + randomized_batch_access=False, **kwargs) + shape = (idx.shape[0], self.n_vars) + if as_sparse: + x = scipy.sparse.csr_matrix(np.zeros(shape)) + else: + x = np.empty(shape) + counter = 0 + for x_batch, _ in g.iterator(): + batch_len = x_batch.shape[0] + x[counter:(counter + batch_len), :] = x_batch + counter += batch_len + return x + + +class DistributedStoreAnndata(DistributedStoreSingleFeatureSpace): in_memory: bool def __init__(self, in_memory: bool, **kwargs): - super(DistributedStoreH5ad, self).__init__(**kwargs) + super(DistributedStoreAnndata, self).__init__(**kwargs) self._x_as_dask = False self.in_memory = in_memory @property - def adata_sliced(self) -> Dict[str, anndata.AnnData]: + def _adata_sliced(self) -> Dict[str, anndata.AnnData]: """ Only exposes the subset and slices of the adata instances contained in ._adata_by_key defined in .indices. """ @@ -633,7 +638,7 @@ def X(self): assert np.all([isinstance(v.X, scipy.sparse.spmatrix) for v in self.adata_by_key.values()]) return scipy.sparse.vstack([v.X for v in self.adata_by_key.values()]) else: - raise NotImplementedError() + raise NotImplementedError("this operation is not efficient with backed objects") @property def obs(self) -> Union[pd.DataFrame]: @@ -647,81 +652,11 @@ def obs(self) -> Union[pd.DataFrame]: for k, v in self.indices.items() ], axis=0, join="inner", ignore_index=False, copy=False) - def _generator( - self, - idx_gen: iter, - var_idx: np.ndarray, - obs_keys: List[str] = [], - return_dense: bool = False, - ) -> iter: - """ - Yields data batches as defined by index sets emitted from index generator. - - :param idx_gen: Generator that yield two elements in each draw: - - np.ndarray: The cells to emit. - - List[Tuple[int, int]: Batch start and end indices. - :param var_idx: The features to emit. - :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available - in self.adata_by_key. - :return: Generator function which yields batch_size at every invocation. - The generator returns a tuple of (.X, .obs). 
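A short usage sketch of the new slicing helpers added above, again continuing the `store_human` instance from the earlier sketches; the index range is illustrative. Both helpers stream batches through the generator instead of materialising the full matrix before indexing.

```python
# store_human: DistributedStoreSingleFeatureSpace from the earlier sketches.
import numpy as np

idx = np.arange(0, 100)  # illustrative selection within the current subset
x_sub = store_human.X_slice(idx=idx, as_sparse=True)   # scipy.sparse matrix slice
adata_sub = store_human.adata_slice(idx=idx)            # AnnData with matching .obs and .var
```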
- """ - adata_sliced = self.adata_sliced - # Speed up access to single object by skipping index overlap operations: - single_object = len(adata_sliced.keys()) == 1 - if not single_object: - idx_dict_global = dict([(k, set(v)) for k, v in self.indices_global.items()]) - - def generator(): - for idx, batch_starts_ends in idx_gen: - for s, e in batch_starts_ends: - idx_i = idx[s:e] - # Match adata objects that overlap to batch: - if single_object: - idx_i_dict = dict([(k, np.sort(idx_i)) for k in adata_sliced.keys()]) - else: - idx_i_set = set(idx_i) - idx_i_dict = dict([ - (k, np.sort(list(idx_i_set.intersection(v)))) - for k, v in idx_dict_global.items() - ]) - # Only retain non-empty. - idx_i_dict = dict([(k, v) for k, v in idx_i_dict.items() if len(v) > 0]) - # I) Prepare data matrix. - x = [ - adata_sliced[k].X[v, :] - for k, v in idx_i_dict.items() - ] - # Move from ArrayView to numpy if backed and dense: - x = [ - xx.toarray() if isinstance(xx, anndata._core.views.ArrayView) else xx - for xx in x - ] - # Do dense conversion now so that col-wise indexing is not slow, often, dense conversion - # would be done later anyway. - if return_dense: - x = [np.asarray(xx.todense()) if isinstance(xx, scipy.sparse.spmatrix) else xx for xx in x] - is_dense = True - else: - is_dense = isinstance(x[0], np.ndarray) - # Concatenate blocks in observation dimension: - if len(x) > 1: - if is_dense: - x = np.concatenate(x, axis=0) - else: - x = scipy.sparse.vstack(x) - else: - x = x[0] - if var_idx is not None: - x = x[:, var_idx] - # Prepare .obs. - obs = pd.concat([ - adata_sliced[k].obs[obs_keys].iloc[v, :] - for k, v in idx_i_dict.items() - ], axis=0, join="inner", ignore_index=True, copy=False) - yield x, obs - - return generator + def _get_generator(self, return_dense: bool = False, **kwargs) -> iter: + idx_dict_global = dict([(k1, (v1, v2)) + for (k1, v1), v2 in zip(self.indices_global.items(), self.indices.values())]) + return GeneratorAnndata(adata_dict=self._adata_sliced, idx_dict_global=idx_dict_global, + return_dense=return_dense, **kwargs) class DistributedStoreDao(DistributedStoreSingleFeatureSpace): @@ -803,49 +738,5 @@ def obs(self) -> pd.DataFrame: for k, v in self.indices.items() ], axis=0, join="inner", ignore_index=True, copy=False) - def _generator( - self, - idx_gen: iter, - var_idx: np.ndarray, - obs_keys: List[str] = [], - ) -> iter: - """ - Yields data batches as defined by index sets emitted from index generator. - - :param idx_gen: Generator that yield two elements in each draw: - - np.ndarray: The cells to emit. - - List[Tuple[int, int]: Batch start and end indices. - :param var_idx: The features to emit. - :param obs_keys: .obs columns to return in the generator. These have to be a subset of the columns available - in self.adata_by_key. - :return: Generator function which yields batch_size at every invocation. - The generator returns a tuple of (.X, .obs). - """ - # Normalise cell indices such that each organism is indexed starting at zero: - # This is required below because each organism is represented as its own dask array. - # TODO this might take a lot of time as the dask array is built. 
- t0 = time.time() - x = self.X - print(f"init X: {time.time() - t0}") - t0 = time.time() - obs = self.obs[obs_keys] - # Redefine index so that .loc indexing can be used instead of .iloc indexing: - obs.index = np.arange(0, obs.shape[0]) - print(f"init obs: {time.time() - t0}") - - def generator(): - # Can all data sets corresponding to one organism as a single array because they share the second dimension - # and dask keeps expression data and obs out of memory. - for idx, batch_starts_ends in idx_gen: - x_temp = x[idx, :] - obs_temp = obs.loc[obs.index[idx], :] # TODO better than iloc? - for s, e in batch_starts_ends: - x_i = x_temp[s:e, :] - if var_idx is not None: - x_i = x_i[:, var_idx] - # Exploit fact that index of obs is just increasing list of integers, so we can use the .loc[] - # indexing instead of .iloc[]: - obs_i = obs_temp.loc[obs_temp.index[s:e], :] - yield x_i, obs_i - - return generator + def _get_generator(self, **kwargs) -> GeneratorDask: + return GeneratorDask(x=self.X, obs=self.obs, **kwargs) diff --git a/sfaira/data/utils.py b/sfaira/data/utils.py index a9108fb19..ee44a2055 100644 --- a/sfaira/data/utils.py +++ b/sfaira/data/utils.py @@ -52,7 +52,7 @@ def map_celltype_to_ontology( queries = [queries] oc = OntologyContainerSfaira() cu = CelltypeUniverse( - cl=oc.cellontology_class, + cl=oc.cell_type, uberon=oc.organ, organism=organism, **kwargs diff --git a/sfaira/data/utils_scripts/create_target_universes.py b/sfaira/data/utils_scripts/create_target_universes.py index ec26c17da..7cde5d9af 100644 --- a/sfaira/data/utils_scripts/create_target_universes.py +++ b/sfaira/data/utils_scripts/create_target_universes.py @@ -36,7 +36,7 @@ celltypes_found = celltypes_found.union( set(store.adatas[k].obs[col_name_annot].values[idx].tolist()) ) - celltypes_found = sorted(list(celltypes_found - {store._adata_ids_sfaira.unknown_celltype_identifier, + celltypes_found = sorted(list(celltypes_found - {store._adata_ids_sfaira.unknown_metadata_identifier, store._adata_ids_sfaira.not_a_cell_celltype_identifier})) if len(celltypes_found) == 0: print(f"WARNING: No cells found for {organism} {organ}, skipping.") diff --git a/sfaira/data/utils_scripts/streamline_selected.py b/sfaira/data/utils_scripts/streamline_selected.py index 6fc201978..dba2bef95 100644 --- a/sfaira/data/utils_scripts/streamline_selected.py +++ b/sfaira/data/utils_scripts/streamline_selected.py @@ -35,7 +35,10 @@ clean_obs=False, clean_var=True, clean_uns=True, - clean_obs_names=False + clean_obs_names=False, + keep_orginal_obs=False, + keep_symbol_obs=True, + keep_id_obs=True, ) ds.collapse_counts() assert len(ds.dataset_groups) == 1, len(ds.dataset_groups) diff --git a/sfaira/data/utils_scripts/test_store.py b/sfaira/data/utils_scripts/test_store.py index 56249fdda..64bd0ecdd 100644 --- a/sfaira/data/utils_scripts/test_store.py +++ b/sfaira/data/utils_scripts/test_store.py @@ -57,9 +57,10 @@ def time_gen(_store, store_format, kwargs) -> List[float]: if kwargs["var_subset"]: gc = sfaira.versions.genomes.genomes.GenomeContainer(assembly="Homo_sapiens.GRCh38.102") gc.subset(symbols=["VTA1", "MLXIPL", "BAZ1B", "RANBP9", "PPARGC1A", "DDX25", "CRYAB"]) - _store.genome_containers = gc + _store.genome_container = gc del kwargs["var_subset"] - _gen = _store.generator(**kwargs)() + _gen, _ = _store.iterator(**kwargs) + _gen = _gen() _measurements = [] for _ in range(N_DRAWS): _t0 = time.time() @@ -103,7 +104,7 @@ def get_idx_dataset_start(_store, k_target): t0 = time.time() store = sfaira.data.load_store(cache_path=path_store, 
store_format=store_type_i) # Include initialisation of generator in timing to time overhead generated here. - _ = store.generator() + _, _ = store.generator() time_measurements_initiate[store_type_i].append(time.time() - t0) memory_measurements_initiate[store_type_i].append(np.sum(list(store.adata_memory_footprint.values()))) diff --git a/sfaira/estimators/keras.py b/sfaira/estimators/keras.py index 9452d5fbe..5888134ec 100644 --- a/sfaira/estimators/keras.py +++ b/sfaira/estimators/keras.py @@ -7,14 +7,19 @@ import tensorflow as tf except ImportError: tf = None -from typing import Union +from typing import List, Union import os import warnings from tqdm import tqdm from sfaira.consts import AdataIdsSfaira, OCS, AdataIds +from sfaira.data.store.base import DistributedStoreBase +from sfaira.data.store.generators import GeneratorSingle +from sfaira.data.store.multi_store import DistributedStoresAnndata from sfaira.data.store.single_store import DistributedStoreSingleFeatureSpace from sfaira.models import BasicModelKeras +from sfaira.models.celltype import BasicModelKerasCelltype +from sfaira.models.embedding import BasicModelKerasEmbedding from sfaira.versions.metadata import CelltypeUniverse, OntologyCl, OntologyObo from sfaira.versions.topologies import TopologyContainer from .losses import LossLoglikelihoodNb, LossLoglikelihoodGaussian, LossCrossentropyAgg, KLLoss @@ -36,13 +41,129 @@ def prepare_sf(x): return sf +def split_idx(data: DistributedStoreSingleFeatureSpace, test_split, val_split): + """ + Split training and evaluation data. + """ + np.random.seed(1) + all_idx = np.arange(0, data.n_obs) # n_obs is both a property of AnnData and DistributedStoreBase + if isinstance(test_split, float) or isinstance(test_split, int): + idx_test = np.sort(np.random.choice( + a=all_idx, + size=round(data.n_obs * test_split), + replace=False, + )) + elif isinstance(test_split, dict): + in_test = np.ones((data.n_obs,), dtype=int) == 1 + for k, v in test_split.items(): + if isinstance(v, bool) or isinstance(v, int) or isinstance(v, list): + v = [v] + idx = data.get_subset_idx(attr_key=k, values=v, excluded_values=None) + # Build continuous vector across all sliced data sets and establish which observations are kept + # in subset. 
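+            # Start from an all-False mask and set the positions whose global index falls into the per-dataset
+            # selection returned by get_subset_idx.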
+ in_test_k = np.ones((data.n_obs,), dtype=int) == 0 + counter = 0 + for kk, vv in data.indices.items(): + if kk in idx.keys() and len(idx[kk]) > 0: + in_test_k[np.where([x in idx[kk] for x in vv])[0] + counter] = True + counter += len(vv) + in_test = np.logical_and(in_test, in_test_k) + idx_test = np.sort(np.where(in_test)[0]) + else: + raise ValueError("type of test_split %s not recognized" % type(test_split)) + print(f"Found {len(idx_test)} out of {data.n_obs} cells that correspond to test data set") + assert len(idx_test) < data.n_obs, f"test set covers full data set, apply a more restrictive test " \ + f"data definiton ({len(idx_test)}, {data.n_obs})" + idx_train_eval = np.array([x for x in all_idx if x not in idx_test]) + np.random.seed(1) + idx_eval = np.sort(np.random.choice( + a=idx_train_eval, + size=round(len(idx_train_eval) * val_split), + replace=False + )) + idx_train = np.sort([x for x in idx_train_eval if x not in idx_eval]) + + # Check that none of the train, test, eval partitions are empty + if not len(idx_test): + warnings.warn("Test partition is empty!") + if not len(idx_eval): + raise ValueError("The evaluation dataset is empty.") + if not len(idx_train): + raise ValueError("The train dataset is empty.") + return idx_train, idx_eval, idx_test + + +def process_tf_dataset(dataset, mode: str, batch_size: int, cache: bool, shuffle_buffer_size: int, prefetch)\ + -> tf.data.Dataset: + if cache: + dataset = dataset.cache() + if mode in ['train', 'train_val']: + dataset = dataset.repeat() + if shuffle_buffer_size > 0: + # Only shuffle in train modes + dataset = dataset.shuffle(buffer_size=shuffle_buffer_size, seed=None, reshuffle_each_iteration=True) + dataset = dataset.batch(batch_size).prefetch(prefetch) + return dataset + + +def get_optimizer(optimizer: str, lr: float): + if optimizer.lower() == "adam": + return tf.keras.optimizers.Adam(learning_rate=lr) + elif optimizer.lower() == "sgd": + return tf.keras.optimizers.SGD(learning_rate=lr) + elif optimizer.lower() == "rmsprop": + return tf.keras.optimizers.RMSprop(learning_rate=lr) + elif optimizer.lower() == "adagrad": + return tf.keras.optimizers.Adagrad(learning_rate=lr) + else: + assert False + + +def assemble_cbs(patience, lr_schedule_factor, lr_schedule_patience, lr_schedule_min_lr, lr_verbose, log_dir, + callbacks) -> List[tf.keras.callbacks.Callback]: + cbs = [tf.keras.callbacks.TerminateOnNaN()] + if patience is not None and patience > 0: + cbs.append(tf.keras.callbacks.EarlyStopping( + monitor='val_loss', + patience=patience, + restore_best_weights=True, + verbose=1 + )) + if lr_schedule_factor is not None and lr_schedule_factor < 1.: + cbs.append(tf.keras.callbacks.ReduceLROnPlateau( + monitor='val_loss', + factor=lr_schedule_factor, + patience=lr_schedule_patience, + min_lr=lr_schedule_min_lr, + verbose=lr_verbose + )) + if log_dir is not None: + cbs.append(tf.keras.callbacks.TensorBoard( + log_dir=log_dir, + histogram_freq=0, + batch_size=32, + write_graph=False, + write_grads=False, + write_images=False, + embeddings_freq=0, + embeddings_layer_names=None, + embeddings_metadata=None, + embeddings_data=None, + update_freq='epoch' + )) + if callbacks is not None: + # callbacks needs to be a list + cbs += callbacks + return cbs + + class EstimatorKeras: """ Estimator base class for keras models. 
""" - data: Union[anndata.AnnData, DistributedStoreSingleFeatureSpace] + data: DistributedStoreSingleFeatureSpace model: Union[BasicModelKeras, None] - topology_container: Union[TopologyContainer, None] + topology_container: TopologyContainer model_id: Union[str, None] weights: Union[np.ndarray, None] model_dir: Union[str, None] @@ -55,7 +176,7 @@ class EstimatorKeras: def __init__( self, - data: Union[anndata.AnnData, np.ndarray, DistributedStoreSingleFeatureSpace], + data: Union[anndata.AnnData, List[anndata.AnnData], DistributedStoreSingleFeatureSpace], model_dir: Union[str, None], model_class: str, model_id: Union[str, None], @@ -64,14 +185,23 @@ def __init__( cache_path: str = os.path.join('cache', ''), adata_ids: AdataIds = AdataIdsSfaira() ): - self.data = data self.model = None self.model_dir = model_dir self.model_id = model_id self.model_class = model_class self.topology_container = model_topology + if isinstance(data, anndata.AnnData): + data = DistributedStoresAnndata(adatas=data).stores[self.organism] + if isinstance(data, list) or isinstance(data, tuple): + for x in data: + assert isinstance(x, anndata.AnnData), f"found element in list that was not anndata but {type(x)}" + data = DistributedStoresAnndata(adatas=data).stores[self.organism] + self.data = data # Prepare store with genome container sub-setting: - if isinstance(self.data, DistributedStoreSingleFeatureSpace): + # This class is tailored for DistributedStoreSingleFeatureSpace but we test for the base class here in the + # constructor so that genome_container can also be set in inheriting classes that may be centred around + # different child classes of DistributedStoreBase. + if isinstance(self.data, DistributedStoreBase): self.data.genome_container = self.topology_container.gc self.history = None @@ -89,7 +219,7 @@ def model_type(self): @property def organism(self): - return self.topology_container.organism + return {"homo_sapiens": "human", "mus_musculus": "mouse"}[self.topology_container.organism] def load_pretrained_weights(self): """ @@ -179,20 +309,32 @@ def _assert_md5_sum( raise ValueError("md5 of %s did not match expectation" % fn) @abc.abstractmethod - def _get_dataset( - self, - idx: Union[np.ndarray, None], - batch_size: Union[int, None], - mode: str, - shuffle_buffer_size: int, - cache_full: bool, - weighted: bool, - retrieval_batch_size: int, - randomized_batch_access: bool, - prefetch: Union[int, None], - ) -> tf.data.Dataset: + def _get_generator(self, **kwargs) -> GeneratorSingle: + """ + Yield a generator based on which a tf dataset can be built. 
+ """ pass + @abc.abstractmethod + def _tf_dataset_kwargs(self, mode: str): + pass + + def get_one_time_tf_dataset(self, idx, mode, batch_size=None, prefetch=None): + batch_size = 128 if batch_size is None else batch_size + prefetch = 10 if prefetch is None else prefetch + tf_kwargs = { + "batch_size": batch_size, + "cache": False, + "prefetch": prefetch, + "shuffle_buffer_size": 0, + } + train_gen = self._get_generator(idx=idx, mode=mode, retrieval_batch_size=128, + randomized_batch_access=False) + train_tf_dataset_kwargs = self._tf_dataset_kwargs(mode=mode) + train_dataset = train_gen.adaptor(generator_type="tensorflow", **train_tf_dataset_kwargs) + train_dataset = process_tf_dataset(dataset=train_dataset, mode=mode, **tf_kwargs) + return train_dataset + def _get_class_dict( self, obs_key: str @@ -207,52 +349,6 @@ def _get_class_dict( label_dict.update({label: float(i)}) return label_dict - def _prepare_data_matrix(self, idx: Union[np.ndarray, None]) -> scipy.sparse.csr_matrix: - """ - Subsets observations x features matrix in .data to observation indices (idx, the split) and features defined - by topology. - - :param idx: Observation index split. - :return: Data matrix - """ - # Check that AnnData is not backed. If backed, assume that these processing steps were done before. - if self.data.isbacked: - raise ValueError("tried running backed AnnData object through standard pipeline") - - else: - # Convert data matrix to csr matrix - if isinstance(self.data.X, np.ndarray): - # Change NaN to zero. This occurs for example in concatenation of anndata instances. - if np.any(np.isnan(self.data.X)): - self.data.X[np.isnan(self.data.X)] = 0 - x = scipy.sparse.csr_matrix(self.data.X) - elif isinstance(self.data.X, scipy.sparse.spmatrix): - x = self.data.X.tocsr() - else: - raise ValueError("data type %s not recognized" % type(self.data.X)) - - # Subset cells by provided idx - if idx is not None: - x = x[idx, :] - - # If the feature space is already mapped to the right reference, return the data matrix immediately - if self.data.n_vars != self.topology_container.n_var or \ - not np.all(self.data.var[self._adata_ids.gene_id_ensembl] == self.topology_container.gc.ensembl): - # Compute indices of genes to keep - data_ids = self.data.var[self._adata_ids.gene_id_ensembl].values.tolist() - target_ids = self.topology_container.gc.ensembl - idx_map = np.array([data_ids.index(z) for z in target_ids]) - # Assert that each ID from target IDs appears exactly once in data IDs: - assert np.all([z in data_ids for z in target_ids]), "not all target feature IDs found in data" - assert np.all([np.sum(z == np.array(data_ids)) <= 1. for z in target_ids]), \ - "duplicated target feature IDs exist in data" - # Map feature space. - x = x[:, idx_map] - print(f"found {len(idx_map)} intersecting features between {x.shape[1]} features in input data set and" - f" {self.topology_container.n_var} features in reference genome") - print(f"found {x.shape[0]} observations") - return x - @abc.abstractmethod def _get_loss(self): pass @@ -272,59 +368,21 @@ def _compile_models( ) def split_train_val_test(self, val_split: float, test_split: Union[float, dict]): - # Split training and evaluation data. 
- np.random.seed(1) - all_idx = np.arange(0, self.data.n_obs) # n_obs is both a property of AnnData and DistributedStoreBase - if isinstance(test_split, float) or isinstance(test_split, int): - self.idx_test = np.sort(np.random.choice( - a=all_idx, - size=round(self.data.n_obs * test_split), - replace=False, - )) - elif isinstance(test_split, dict): - in_test = np.ones((self.data.n_obs,), dtype=int) == 1 - for k, v in test_split.items(): - if isinstance(v, bool) or isinstance(v, int) or isinstance(v, list): - v = [v] - if isinstance(self.data, anndata.AnnData): - if k not in self.data.obs.columns: - raise ValueError(f"Did not find column {k} used to define test set in self.data.") - in_test = np.logical_and(in_test, np.array([x in v for x in self.data.obs[k].values])) - elif isinstance(self.data, DistributedStoreSingleFeatureSpace): - idx = self.data.get_subset_idx(attr_key=k, values=v, excluded_values=None) - # Build continuous vector across all sliced data sets and establish which observations are kept - # in subset. - in_test_k = np.ones((self.data.n_obs,), dtype=int) == 0 - counter = 0 - for kk, vv in self.data.indices.items(): - if kk in idx.keys() and len(idx[kk]) > 0: - in_test_k[np.where([x in idx[kk] for x in vv])[0] + counter] = True - counter += len(vv) - in_test = np.logical_and(in_test, in_test_k) - else: - assert False - self.idx_test = np.sort(np.where(in_test)[0]) - else: - raise ValueError("type of test_split %s not recognized" % type(test_split)) - print(f"Found {len(self.idx_test)} out of {self.data.n_obs} cells that correspond to test data set") - assert len(self.idx_test) < self.data.n_obs, "test set covers full data set, apply a more restrictive test " \ - "data definiton" - idx_train_eval = np.array([x for x in all_idx if x not in self.idx_test]) - np.random.seed(1) - self.idx_eval = np.sort(np.random.choice( - a=idx_train_eval, - size=round(len(idx_train_eval) * val_split), - replace=False - )) - self.idx_train = np.sort([x for x in idx_train_eval if x not in self.idx_eval]) + """ + Split indices in store into train, valiation and test split. + """ + idx_train, idx_eval, idx_test = split_idx(data=self.data, test_split=test_split, val_split=val_split) + self.idx_train = idx_train + self.idx_eval = idx_eval + self.idx_test = idx_test - # Check that none of the train, test, eval partitions are empty - if not len(self.idx_test): - warnings.warn("Test partition is empty!") - if not len(self.idx_eval): - raise ValueError("The evaluation dataset is empty.") - if not len(self.idx_train): - raise ValueError("The train dataset is empty.") + def _process_idx_for_eval(self, idx): + """ + Defaults to all observations if no indices are defined. + """ + if idx is None: + idx = np.arange(0, self.data.n_obs) + return idx def train( self, @@ -383,102 +441,41 @@ def train( :param verbose: :return: """ - # Set optimizer - if optimizer.lower() == "adam": - optim = tf.keras.optimizers.Adam(learning_rate=lr) - elif optimizer.lower() == "sgd": - optim = tf.keras.optimizers.SGD(learning_rate=lr) - elif optimizer.lower() == "rmsprop": - optim = tf.keras.optimizers.RMSprop(learning_rate=lr) - elif optimizer.lower() == "adagrad": - optim = tf.keras.optimizers.Adagrad(learning_rate=lr) - else: - assert False # Save training settings to allow model restoring. 
- self.train_hyperparam = { - "epochs": epochs, - "max_steps_per_epoch": max_steps_per_epoch, - "optimizer": optimizer, - "lr": lr, - "batch_size": batch_size, - "validation_split": validation_split, - "validation_batch_size": validation_batch_size, - "max_validation_steps": max_validation_steps, - "patience": patience, - "lr_schedule_min_lr": lr_schedule_min_lr, - "lr_schedule_factor": lr_schedule_factor, - "lr_schedule_patience": lr_schedule_patience, - "log_dir": log_dir, - "weighted": weighted - } + self.train_hyperparam = {"epochs": epochs, "max_steps_per_epoch": max_steps_per_epoch, "optimizer": optimizer, + "lr": lr, "batch_size": batch_size, "validation_split": validation_split, + "validation_batch_size": validation_batch_size, + "max_validation_steps": max_validation_steps, "patience": patience, + "lr_schedule_min_lr": lr_schedule_min_lr, "lr_schedule_factor": lr_schedule_factor, + "lr_schedule_patience": lr_schedule_patience, "log_dir": log_dir, "weighted": weighted} # Set callbacks. - cbs = [tf.keras.callbacks.TerminateOnNaN()] - if patience is not None and patience > 0: - cbs.append(tf.keras.callbacks.EarlyStopping( - monitor='val_loss', - patience=patience, - restore_best_weights=True, - verbose=verbose - )) - if lr_schedule_factor is not None and lr_schedule_factor < 1.: - cbs.append(tf.keras.callbacks.ReduceLROnPlateau( - monitor='val_loss', - factor=lr_schedule_factor, - patience=lr_schedule_patience, - min_lr=lr_schedule_min_lr, - verbose=verbose - )) - if log_dir is not None: - cbs.append(tf.keras.callbacks.TensorBoard( - log_dir=log_dir, - histogram_freq=0, - batch_size=32, - write_graph=False, - write_grads=False, - write_images=False, - embeddings_freq=0, - embeddings_layer_names=None, - embeddings_metadata=None, - embeddings_data=None, - update_freq='epoch' - )) - - if callbacks is not None: - # callbacks needs to be a list - cbs += callbacks + cbs = assemble_cbs(patience=patience, lr_schedule_factor=lr_schedule_factor, + lr_schedule_patience=lr_schedule_patience, lr_schedule_min_lr=lr_schedule_min_lr, + lr_verbose=verbose, log_dir=log_dir, callbacks=callbacks) # Check randomisation settings: if shuffle_buffer_size is not None and shuffle_buffer_size > 0 and randomized_batch_access: raise ValueError("You are using shuffle_buffer_size and randomized_batch_access, this is likely not " "intended.") + shuffle_buffer_size = shuffle_buffer_size if shuffle_buffer_size is not None else 0 if cache_full and randomized_batch_access: raise ValueError("You are using cache_full and randomized_batch_access, this is likely not intended.") self.split_train_val_test(val_split=validation_split, test_split=test_split) - self._compile_models(optimizer=optim) - shuffle_buffer_size = shuffle_buffer_size if shuffle_buffer_size is not None else 0 - train_dataset = self._get_dataset( - idx=self.idx_train, - batch_size=batch_size, - retrieval_batch_size=retrieval_batch_size, - mode='train', - shuffle_buffer_size=min(shuffle_buffer_size, len(self.idx_train)), - weighted=weighted, - cache_full=cache_full, - randomized_batch_access=randomized_batch_access, - prefetch=prefetch, - ) - eval_dataset = self._get_dataset( - idx=self.idx_eval, - batch_size=validation_batch_size, - retrieval_batch_size=retrieval_batch_size, - mode='train_val', - shuffle_buffer_size=min(shuffle_buffer_size, len(self.idx_eval)), - weighted=weighted, - cache_full=cache_full, - randomized_batch_access=randomized_batch_access, - prefetch=prefetch, - ) + self._compile_models(optimizer=get_optimizer(optimizer=optimizer, lr=lr)) 
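+        # The tf.data input pipelines for the train and validation splits are assembled below from the store
+        # generators via the tensorflow adaptor and process_tf_dataset.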
+
+        tf_kwargs = {"batch_size": batch_size, "cache": cache_full, "prefetch": prefetch,
+                     "shuffle_buffer_size": min(shuffle_buffer_size, len(self.idx_train))}
+        train_gen = self._get_generator(idx=self.idx_train, mode='train', retrieval_batch_size=retrieval_batch_size,
+                                        randomized_batch_access=randomized_batch_access, weighted=weighted)
+        train_tf_dataset_kwargs = self._tf_dataset_kwargs(mode="train")
+        train_dataset = train_gen.adaptor(generator_type="tensorflow", **train_tf_dataset_kwargs)
+        train_dataset = process_tf_dataset(dataset=train_dataset, mode="train", **tf_kwargs)
+        val_gen = self._get_generator(idx=self.idx_eval, mode='train_val', retrieval_batch_size=retrieval_batch_size,
+                                      randomized_batch_access=randomized_batch_access, weighted=weighted)
+        val_tf_dataset_kwargs = self._tf_dataset_kwargs(mode="train_val")
+        val_dataset = val_gen.adaptor(generator_type="tensorflow", **val_tf_dataset_kwargs)
+        val_dataset = process_tf_dataset(dataset=val_dataset, mode="train_val", **tf_kwargs)

         steps_per_epoch = min(max(len(self.idx_train) // batch_size, 1), max_steps_per_epoch)
         validation_steps = min(max(len(self.idx_eval) // validation_batch_size, 1), max_validation_steps)
@@ -488,7 +485,7 @@ def train(
             epochs=epochs,
             steps_per_epoch=steps_per_epoch,
             callbacks=cbs,
-            validation_data=eval_dataset,
+            validation_data=val_dataset,
             validation_steps=validation_steps,
             verbose=verbose
         ).history
@@ -515,6 +512,8 @@ class EstimatorKerasEmbedding(EstimatorKeras):
     Estimator class for the embedding model.
     """
 
+    model: Union[BasicModelKerasEmbedding, None]
+
     def __init__(
         self,
         data: Union[anndata.AnnData, np.ndarray, DistributedStoreSingleFeatureSpace],
@@ -563,220 +562,46 @@ def init_model(
             override_hyperpar=override_hyperpar
         )
 
-    @staticmethod
-    def _get_output_dim(n_features, model_type, mode='train'):
+    def _tf_dataset_kwargs(self, mode: str):
+        # Determine model type [ae, vae(iaf, vamp)]
+        model_type = "vae" if self.model_type[:3] == "vae" else "ae"
         if mode == 'predict':  # Output shape is same for predict mode regardless of model type
             output_types = (tf.float32, tf.float32),
-            output_shapes = (n_features, ()),
+            output_shapes = (self.data.n_vars, ()),
         elif model_type == "vae":
             output_types = ((tf.float32, tf.float32), (tf.float32, tf.float32))
-            output_shapes = ((n_features, ()), (n_features, ()))
+            output_shapes = ((self.data.n_vars, ()), (self.data.n_vars, ()))
         else:
             output_types = ((tf.float32, tf.float32), tf.float32)
-            output_shapes = ((n_features, ()), n_features)
-
-        return output_types, output_shapes
+            output_shapes = ((self.data.n_vars, ()), self.data.n_vars)
+        return {"output_types": output_types, "output_shapes": output_shapes}
 
-    def _get_base_generator(
+    def _get_generator(
         self,
-        generator_helper,
         idx: Union[np.ndarray, None],
-        batch_size: int,
+        mode: str,
+        retrieval_batch_size: int,
         randomized_batch_access: bool,
+        **kwargs
     ):
-        """
-        Yield a basic generator based on which a tf dataset can be built.
-
-        The signature of this generator can be modified through generator_helper.
-
-        :param generator_helper: Python function that should take (x_sample,) as an input:
-
-            - x_sample is a gene expression vector of a cell
-        :param idx: Indices of data set to include in generator.
-        :param batch_size: Number of observations read from disk in each batched access.
-        :param randomized_batch_access: Whether to randomize batches during reading (in generator).
Lifts necessity of - using a shuffle buffer on generator, however, batch composition stays unchanged over epochs unless there - is overhangs in retrieval_batch_size in the raw data files, which often happens and results in modest - changes in batch composition. - :return: - """ - if idx is None: - idx = np.arange(0, self.data.n_obs) - - # Prepare data reading according to whether anndata is backed or not: - if self.using_store: - generator_raw, _ = self.data.generator( - idx=idx, - batch_size=batch_size, - obs_keys=[], - return_dense=True, - randomized_batch_access=randomized_batch_access, - ) - - def generator(): - for z in generator_raw(): - x_sample = z[0] - if isinstance(x_sample, scipy.sparse.csr_matrix): - x_sample = x_sample.todense() - x_sample = np.asarray(x_sample) - for i in range(x_sample.shape[0]): - yield generator_helper(x_sample=x_sample[i]) - - n_features = self.data.n_vars - n_samples = self.data.n_obs - else: - x = self.data.X if self.data.isbacked else self._prepare_data_matrix(idx=idx) - indices = idx if self.data.isbacked else range(x.shape[0]) - n_obs = len(indices) - remainder = n_obs % batch_size - batch_starts_ends = [ - (int(x * batch_size), int(x * batch_size) + batch_size) - for x in np.arange(0, n_obs // batch_size + int(remainder > 0)) - ] - - def generator(): - is_sparse = isinstance(x[0, :], scipy.sparse.spmatrix) - for s, e in batch_starts_ends: - x_sample = np.asarray(x[indices[s:e], :].todense()) if is_sparse \ - else x[indices[s:e], :] - for i in range(x_sample.shape[0]): - yield generator_helper(x_sample=x_sample[i]) - - n_features = x.shape[1] - n_samples = x.shape[0] - - return generator, n_samples, n_features - - def _get_dataset( - self, - idx: Union[np.ndarray, None], - batch_size: Union[int, None], - mode: str, - shuffle_buffer_size: int = int(1e7), - cache_full: bool = False, - weighted: bool = False, - retrieval_batch_size: int = 128, - randomized_batch_access: bool = False, - prefetch: Union[int, None] = 1, - ) -> tf.data.Dataset: - """ - - :param idx: - :param batch_size: - :param mode: - :param shuffle_buffer_size: - :param weighted: Whether to use weights. Not implemented for embedding models yet. - :param retrieval_batch_size: Number of observations read from disk in each batched access. - :param randomized_batch_access: Whether to randomize batches during reading (in generator). Lifts necessity of - using a shuffle buffer on generator, however, batch composition stays unchanged over epochs unless there - is overhangs in retrieval_batch_size in the raw data files, which often happens and results in modest - changes in batch composition. - :return: - """ - # Determine model type [ae, vae(iaf, vamp)] + # Define constants used by map_fn in outer name space so that they are not created for each sample. 
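+        # map_fn converts each retrieved expression/obs pair into the (inputs, targets) structure expected by the
+        # tf.data adaptor of the store generator.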
model_type = "vae" if self.model_type[:3] == "vae" else "ae" - if mode in ['train', 'train_val', 'eval', 'predict']: - def generator_helper(x_sample): - sf_sample = prepare_sf(x=x_sample)[0] - if mode == 'predict': - return (x_sample, sf_sample), - elif model_type == "vae": - return (x_sample, sf_sample), (x_sample, sf_sample) - else: - return (x_sample, sf_sample), x_sample - - generator, n_samples, n_features = self._get_base_generator( - generator_helper=generator_helper, - idx=idx, - batch_size=retrieval_batch_size, - randomized_batch_access=randomized_batch_access, - ) - output_types, output_shapes = self._get_output_dim(n_features=n_features, model_type=model_type, mode=mode) - dataset = tf.data.Dataset.from_generator( - generator=generator, - output_types=output_types, - output_shapes=output_shapes - ) - if cache_full: - dataset = dataset.cache() - # Only shuffle in train modes - if mode in ['train', 'train_val']: - dataset = dataset.repeat() - if shuffle_buffer_size is not None and shuffle_buffer_size > 0: - dataset = dataset.shuffle( - buffer_size=min(n_samples, shuffle_buffer_size), - seed=None, - reshuffle_each_iteration=True) - if prefetch is None: - prefetch = tf.data.AUTOTUNE - dataset = dataset.batch(batch_size, drop_remainder=False).prefetch(prefetch) - - return dataset - - elif mode == 'gradient_method': # TODO depreceate this code - # Prepare data reading according to whether anndata is backed or not: - cell_to_class = self._get_class_dict(obs_key=self._adata_ids.cellontology_class) - if self.using_store: - n_features = self.data.n_vars - generator_raw = self.data.generator( - idx=idx, - batch_size=1, - obs_keys=[self._adata_ids.cellontology_class], - return_dense=True, - ) - - def generator(): - for z in generator_raw(): - x_sample = z[0] - if isinstance(x_sample, scipy.sparse.csr_matrix): - x_sample = x_sample.todense() - x_sample = np.asarray(x_sample).flatten() - sf_sample = prepare_sf(x=x_sample)[0] - y_sample = z[1][self._adata_ids.cellontology_class].values[0] - yield (x_sample, sf_sample), (x_sample, cell_to_class[y_sample]) - - elif isinstance(self.data, anndata.AnnData) and self.data.isbacked: - if idx is None: - idx = np.arange(0, self.data.n_obs) - n_features = self.data.X.shape[1] - - def generator(): - sparse = isinstance(self.data.X[0, :], scipy.sparse.spmatrix) - for i in idx: - x_sample = self.data.X[i, :].toarray().flatten() if sparse else self.data.X[i, :].flatten() - sf_sample = prepare_sf(x=x_sample)[0] - y_sample = self.data.obs[self._adata_ids.cellontology_id][i] - yield (x_sample, sf_sample), (x_sample, cell_to_class[y_sample]) + def map_fn(x_sample, obs_sample): + x_sample = np.asarray(x_sample).flatten() + sf_sample = prepare_sf(x=x_sample).flatten()[0] + if mode == 'predict': + output = (x_sample, sf_sample), + elif model_type == "vae": + output = (x_sample, sf_sample), (x_sample, sf_sample), else: - if idx is None: - idx = np.arange(0, self.data.n_obs) - x = self._prepare_data_matrix(idx=idx) - sf = prepare_sf(x=x) - y = self.data.obs[self._adata_ids.cellontology_class].values[idx] - # for gradients per celltype in compute_gradients_input() - n_features = x.shape[1] - - def generator(): - for i in range(x.shape[0]): - yield (x[i, :].toarray().flatten(), sf[i]), (x[i, :].toarray().flatten(), cell_to_class[y[i]]) - - output_types, output_shapes = self._get_output_dim(n_features, 'vae') - dataset = tf.data.Dataset.from_generator( - generator=generator, - output_types=output_types, - output_shapes=output_shapes - ) - dataset = dataset.shuffle( - 
buffer_size=shuffle_buffer_size, - seed=None, - reshuffle_each_iteration=True - ).batch(batch_size, drop_remainder=False).prefetch(prefetch) - - return dataset + output = (x_sample, sf_sample), x_sample + return output - else: - raise ValueError(f'Mode {mode} not recognised. Should be "train", "eval" or" predict"') + g = self.data.generator(idx=idx, retrieval_batch_size=retrieval_batch_size, obs_keys=[], map_fn=map_fn, + return_dense=True, randomized_batch_access=randomized_batch_access, + random_access=False) + return g def _get_loss(self): if self.topology_container.topology["hyper_parameters"]["output_layer"] in [ @@ -827,15 +652,9 @@ def evaluate_any(self, idx, batch_size: int = 128, max_steps: int = np.inf): :param max_steps: Maximum steps before evaluation round is considered complete. :return: Dictionary of metric names and values. """ - if idx is None or idx.any(): # true if the array is not empty or if the passed value is None - idx = np.arange(0, self.data.n_obs) if idx is None else idx - dataset = self._get_dataset( - idx=idx, - batch_size=batch_size, - mode='eval', - retrieval_batch_size=128, - shuffle_buffer_size=0, - ) + idx = self._process_idx_for_eval(idx=idx) + if idx is not None: + dataset = self.get_one_time_tf_dataset(idx=idx, batch_size=batch_size, mode='eval') steps = min(max(len(idx) // batch_size, 1), max_steps) results = self.model.training_model.evaluate(x=dataset, steps=steps) return dict(zip(self.model.training_model.metrics_names, results)) @@ -852,7 +671,11 @@ def evaluate(self, batch_size: int = 128, max_steps: int = np.inf): :param max_steps: Maximum steps before evaluation round is considered complete. :return: Dictionary of metric names and values. """ - return self.evaluate_any(idx=self.idx_test, batch_size=batch_size, max_steps=max_steps) + idx = self._process_idx_for_eval(idx=self.idx_test) + if idx is not None: + return self.evaluate_any(idx=self.idx_test, batch_size=batch_size, max_steps=max_steps) + else: + return {} def predict(self, batch_size: int = 128): """ @@ -861,53 +684,25 @@ def predict(self, batch_size: int = 128): :return: prediction """ - if self.idx_test is None or self.idx_test.any(): - dataset = self._get_dataset( - idx=self.idx_test, - batch_size=batch_size, - mode='predict', - retrieval_batch_size=128, - shuffle_buffer_size=0, - ) + idx = self._process_idx_for_eval(idx=self.idx_test) + if idx is not None: + dataset = self.get_one_time_tf_dataset(idx=idx, batch_size=batch_size, mode='predict') return self.model.predict_reconstructed(x=dataset) else: return np.array([]) - def predict_embedding(self, batch_size: int = 128): + def predict_embedding(self, batch_size: int = 128, variational: bool = False): """ return the prediction in the latent space (z_mean for variational models) + :params variational: Whether toreturn the prediction of z, z_mean, z_log_var in the variational latent space. 
:return: latent space """ - if self.idx_test is None or self.idx_test.any(): - dataset = self._get_dataset( - idx=self.idx_test, - batch_size=batch_size, - mode='predict', - retrieval_batch_size=128, - shuffle_buffer_size=0, - ) - return self.model.predict_embedding(x=dataset, variational=False) - else: - return np.array([]) - - def predict_embedding_variational(self, batch_size: int = 128, max_steps: int = np.inf): - """ - return the prediction of z, z_mean, z_log_var in the variational latent space - - :return: - sample of latent space, mean of latent space, variance of latent space - """ - if self.idx_test is None or self.idx_test: - dataset = self._get_dataset( - idx=self.idx_test, - batch_size=batch_size, - mode='predict', - retrieval_batch_size=128, - shuffle_buffer_size=0, - ) - return self.model.predict_embedding(x=dataset, variational=True) + idx = self._process_idx_for_eval(idx=self.idx_test) + if len(idx) > 0: + dataset = self.get_one_time_tf_dataset(idx=idx, batch_size=batch_size, mode='predict') + return self.model.predict_embedding(x=dataset, variational=variational) else: return np.array([]) @@ -918,6 +713,7 @@ def compute_gradients_input( abs_gradients: bool = True, per_celltype: bool = False ): + # TODO may need to be adapted to new dataset / generator format if test_data: idx = self.idx_test if self.idx_test is None: @@ -928,14 +724,10 @@ def compute_gradients_input( idx = None n_obs = self.data.X.shape[0] - ds = self._get_dataset( - idx=idx, - batch_size=batch_size, - mode='gradient_method', # to get a tf.GradientTape compatible data set - ) + ds = self.get_one_time_tf_dataset(idx=idx, batch_size=batch_size, mode='gradient_method') if per_celltype: - cell_to_id = self._get_class_dict(obs_key=self._adata_ids.cellontology_class) + cell_to_id = self._get_class_dict(obs_key=self._adata_ids.cell_type) cell_names = cell_to_id.keys() cell_id = cell_to_id.values() id_to_cell = dict([(key, value) for (key, value) in zip(cell_id, cell_names)]) @@ -1000,6 +792,7 @@ class EstimatorKerasCelltype(EstimatorKeras): """ celltype_universe: CelltypeUniverse + model: Union[BasicModelKerasCelltype, None] def __init__( self, @@ -1026,24 +819,10 @@ def __init__( ) if remove_unlabeled_cells: # Remove cells without type label from store: - if isinstance(self.data, DistributedStoreSingleFeatureSpace): - self.data.subset(attr_key="cellontology_class", excluded_values=[ - self._adata_ids.unknown_celltype_identifier, - self._adata_ids.not_a_cell_celltype_identifier, - None, # TODO: it may be possible to remove this in the future - np.nan, # TODO: it may be possible to remove this in the future - ]) - elif isinstance(self.data, anndata.AnnData): - self.data = self.data[np.where([ - x not in [ - self._adata_ids.unknown_celltype_identifier, - self._adata_ids.not_a_cell_celltype_identifier, - None, # TODO: it may be possible to remove this in the future - np.nan, # TODO: it may be possible to remove this in the future - ] for x in self.data.obs[self._adata_ids.cellontology_class].values - ])[0], :] - else: - assert False + self.data.subset(attr_key="cell_type", excluded_values=[ + self._adata_ids.unknown_metadata_identifier, + self._adata_ids.not_a_cell_celltype_identifier, + ]) assert "cl" in self.topology_container.output.keys(), self.topology_container.output.keys() assert "targets" in self.topology_container.output.keys(), self.topology_container.output.keys() self.max_class_weight = max_class_weight @@ -1101,7 +880,7 @@ def encoder(x) -> np.ndarray: # Encodes unknowns to empty rows. 
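+            # Labels marked as unknown or not-a-cell map to an empty leaf-index set and therefore to an all-zero row.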
idx = [ leave_maps[y] if y not in [ - self._adata_ids.unknown_celltype_identifier, + self._adata_ids.unknown_metadata_identifier, self._adata_ids.not_a_cell_celltype_identifier, ] else np.array([]) for y in x @@ -1120,6 +899,7 @@ def _get_celltype_out( idx: Union[np.ndarray, None], ): """ + TODO depreceate, carry over weight code to _get_generator Build one hot encoded cell type output tensor and observation-wise weight matrix. :param lookup_ontology: list of ontology names to consider. @@ -1131,7 +911,7 @@ def _get_celltype_out( onehot_encoder = self._one_hot_encoder() y = np.concatenate([ onehot_encoder(z) - for z in self.data.obs[self._adata_ids.cellontology_id].values[idx].tolist() + for z in self.data.obs[self._adata_ids.cell_type + self._adata_ids.onto_id_suffix].values[idx].tolist() ], axis=0) # Distribute aggregated class weight for computation of weights: freq = np.mean(y / np.sum(y, axis=1, keepdims=True), axis=0, keepdims=True) @@ -1143,169 +923,53 @@ def _get_celltype_out( ).flatten() return weights, y - @staticmethod - def _get_output_dim(n_features, n_labels, mode): + def _tf_dataset_kwargs(self, mode): if mode == 'predict': output_types = (tf.float32,) - output_shapes = (tf.TensorShape([n_features]),) + output_shapes = (tf.TensorShape([self.data.n_vars]),) else: output_types = (tf.float32, tf.float32, tf.float32) output_shapes = ( - (tf.TensorShape([n_features])), - tf.TensorShape([n_labels]), + (tf.TensorShape([self.data.n_vars])), + tf.TensorShape([self.ntypes]), tf.TensorShape([]) ) + return {"output_types": output_types, "output_shapes": output_shapes} - return output_types, output_shapes - - def _get_base_generator( - self, - idx: Union[np.ndarray, None], - yield_labels: bool, - weighted: bool, - batch_size: int, - randomized_batch_access: bool, - **kwargs, - ): - """ - Yield a basic generator based on which a tf dataset can be built. - - The signature of this generator can be modified through generator_helper. - - :param generator_helper: Python function that should take (x_sample, y_sample, w_sample) as an input: - - - x_sample is a gene expression vector of a cell - - y_sample is a one-hot encoded label vector of a cell - - w_sample is a weight scalar of a cell - :param idx: Indicies of data set to include in generator. - :param yield_labels: - :param batch_size: Number of observations read from disk in each batched access. - :param randomized_batch_access: Whether to randomize batches during reading (in generator). Lifts necessity of - using a shuffle buffer on generator, however, batch composition stays unchanged over epochs unless there - is overhangs in retrieval_batch_size in the raw data files, which often happens and results in modest - changes in batch composition. 
- :return: - """ - if idx is None: - idx = np.arange(0, self.data.n_obs) - - # Prepare data reading according to whether anndata is backed or not: - if self.using_store: - if weighted: - raise ValueError("using weights with store is not supported yet") - generator_raw, _ = self.data.generator( - idx=idx, - batch_size=batch_size, - obs_keys=[self._adata_ids.cellontology_id], - return_dense=True, - randomized_batch_access=randomized_batch_access, - ) - if yield_labels: - onehot_encoder = self._one_hot_encoder() - - def generator(): - for z in generator_raw(): - x_sample = z[0] - if isinstance(x_sample, scipy.sparse.csr_matrix): - x_sample = x_sample.todense() - x_sample = np.asarray(x_sample) - if yield_labels: - y_sample = onehot_encoder(z[1][self._adata_ids.cellontology_id].values) - for i in range(x_sample.shape[0]): - if y_sample[i].sum() > 0: - yield x_sample[i], y_sample[i], 1. - else: - for i in range(x_sample.shape[0]): - yield x_sample[i], - n_features = self.data.n_vars - n_samples = self.data.n_obs - else: - if yield_labels: - weights, y = self._get_celltype_out(idx=idx) - if not weighted: - weights = np.ones_like(weights) - x = self.data.X if self.data.isbacked else self._prepare_data_matrix(idx=idx) - is_sparse = isinstance(x, scipy.sparse.spmatrix) - indices = idx if self.data.isbacked else range(x.shape[0]) - n_obs = len(indices) - remainder = n_obs % batch_size - batch_starts_ends = [ - (int(x * batch_size), int(x * batch_size) + batch_size) - for x in np.arange(0, n_obs // batch_size + int(remainder > 0)) - ] - - def generator(): - for s, e in batch_starts_ends: - x_sample = np.asarray(x[indices[s:e], :].todense()) if is_sparse else x[indices[s:e], :] - if yield_labels: - y_sample = y[indices[s:e], :] - w_sample = weights[indices[s:e]] - for i in range(x_sample.shape[0]): - if y_sample[i].sum() > 0: - yield x_sample[i], y_sample[i], w_sample[i] - else: - for i in range(x_sample.shape[0]): - yield x_sample[i], - - n_features = x.shape[1] - n_samples = x.shape[0] - - n_labels = self.celltype_universe.onto_cl.n_leaves - return generator, n_samples, n_features, n_labels - - def _get_dataset( + def _get_generator( self, idx: Union[np.ndarray, None], - batch_size: Union[int, None], mode: str, - shuffle_buffer_size: int = int(1e7), - cache_full: bool = False, + retrieval_batch_size: int, + randomized_batch_access: bool, weighted: bool = False, - retrieval_batch_size: int = 128, - randomized_batch_access: bool = False, - prefetch: Union[int, None] = 1, - ) -> tf.data.Dataset: - """ + **kwargs + ) -> GeneratorSingle: + # Define constants used by map_fn in outer name space so that they are not created for each sample. + if weighted: + raise ValueError("using weights with store is not supported yet") + yield_labels = mode in ["train", "train_val", "eval", "test"] + if yield_labels: + onehot_encoder = self._one_hot_encoder() + + def map_fn(x_sample, obs_sample): + x_sample = np.asarray(x_sample).flatten() + if yield_labels: + y_sample = onehot_encoder(obs_sample[self._adata_ids.cell_type + self._adata_ids.onto_id_suffix].values) + y_sample = y_sample.flatten() + if y_sample.sum() > 0: + output = x_sample, y_sample, 1. + else: + output = None + else: + output = x_sample, + return output - :param idx: - :param batch_size: - :param mode: - :param shuffle_buffer_size: - :param weighted: Whether to use weights. - :param retrieval_batch_size: Number of observations read from disk in each batched access. 
- :param randomized_batch_access: Whether to randomize batches during reading (in generator). Lifts necessity of - using a shuffle buffer on generator, however, batch composition stays unchanged over epochs unless there - is overhangs in retrieval_batch_size in the raw data files, which often happens and results in modest - changes in batch composition. - :return: - """ - generator, n_samples, n_features, n_labels = self._get_base_generator( - idx=idx, - yield_labels=mode in ['train', 'train_val', 'eval'], - weighted=weighted, - batch_size=retrieval_batch_size, - randomized_batch_access=randomized_batch_access, - ) - output_types, output_shapes = self._get_output_dim(n_features=n_features, n_labels=n_labels, mode=mode) - dataset = tf.data.Dataset.from_generator( - generator=generator, - output_types=output_types, - output_shapes=output_shapes - ) - if cache_full: - dataset = dataset.cache() - if mode == 'train' or mode == 'train_val': - dataset = dataset.repeat() - if shuffle_buffer_size is not None and shuffle_buffer_size > 0: - dataset = dataset.shuffle( - buffer_size=min(n_samples, shuffle_buffer_size), - seed=None, - reshuffle_each_iteration=True) - if prefetch is None: - prefetch = tf.data.AUTOTUNE - dataset = dataset.batch(batch_size, drop_remainder=False).prefetch(prefetch) - - return dataset + g = self.data.generator(idx=idx, retrieval_batch_size=retrieval_batch_size, + obs_keys=[self._adata_ids.cell_type + self._adata_ids.onto_id_suffix], map_fn=map_fn, + return_dense=True, randomized_batch_access=randomized_batch_access, + random_access=False) + return g def _get_loss(self): return LossCrossentropyAgg() @@ -1328,15 +992,9 @@ def predict(self, batch_size: int = 128, max_steps: int = np.inf): :param max_steps: Maximum steps before evaluation round is considered complete. :return: Prediction tensor. """ - idx = self.idx_test - if idx is None or idx.any(): - dataset = self._get_dataset( - idx=idx, - batch_size=batch_size, - mode='predict', - retrieval_batch_size=128, - shuffle_buffer_size=0, - ) + idx = self._process_idx_for_eval(idx=self.idx_test) + if len(idx) > 0: + dataset = self.get_one_time_tf_dataset(idx=idx, batch_size=batch_size, mode='predict') return self.model.training_model.predict(x=dataset) else: return np.array([]) @@ -1347,13 +1005,9 @@ def ytrue(self, batch_size: int = 128, max_steps: int = np.inf): :return: true labels """ - if self.idx_test is None or self.idx_test.any(): - dataset = self._get_dataset( - idx=self.idx_test, - batch_size=batch_size, - mode='eval', - shuffle_buffer_size=0, - ) + idx = self._process_idx_for_eval(idx=self.idx_test) + if len(idx) > 0: + dataset = self.get_one_time_tf_dataset(idx=idx, batch_size=batch_size, mode='eval') y_true = [] for _, y, _ in dataset.as_numpy_iterator(): y_true.append(y) @@ -1374,16 +1028,9 @@ def evaluate_any(self, idx, batch_size: int = 128, max_steps: int = np.inf, weig :param weighted: Whether to use class weights in evaluation. :return: Dictionary of metric names and values. 
""" - if idx is None or idx.any(): # true if the array is not empty or if the passed value is None - idx = np.arange(0, self.data.n_obs) if idx is None else idx - dataset = self._get_dataset( - idx=idx, - batch_size=batch_size, - mode='eval', - weighted=weighted, - retrieval_batch_size=128, - shuffle_buffer_size=0, - ) + idx = self._process_idx_for_eval(idx=idx) + if len(idx) > 0: + dataset = self.get_one_time_tf_dataset(idx=idx, batch_size=batch_size, mode='eval') results = self.model.training_model.evaluate(x=dataset) return dict(zip(self.model.training_model.metrics_names, results)) else: @@ -1407,19 +1054,16 @@ def compute_gradients_input( test_data: bool = False, abs_gradients: bool = True ): - + # TODO may need to be adapted to new dataset / generator format if test_data: idx = self.idx_test n_obs = len(self.idx_test) else: idx = None - n_obs = self.data.X.shape[0] + n_obs = self.data.n_obs - ds = self._get_dataset( - idx=idx, - batch_size=64, - mode='train_val' # to get a tf.GradientTape compatible data set - ) + # to get a tf.GradientTape compatible data set + ds = self.get_one_time_tf_dataset(idx=idx, batch_size=64, mode='train_val') grads_x = 0 # Loop over sub-selected data set and sum gradients across all selected observations. model = tf.keras.Model( diff --git a/sfaira/models/celltype/__init__.py b/sfaira/models/celltype/__init__.py index 9ada0b648..6dc8e3ae7 100644 --- a/sfaira/models/celltype/__init__.py +++ b/sfaira/models/celltype/__init__.py @@ -1,2 +1,3 @@ +from sfaira.models.celltype.base import BasicModelKerasCelltype from sfaira.models.celltype.marker import CellTypeMarker, CellTypeMarkerVersioned from sfaira.models.celltype.mlp import CellTypeMlp, CellTypeMlpVersioned diff --git a/sfaira/models/celltype/base.py b/sfaira/models/celltype/base.py new file mode 100644 index 000000000..6099dd092 --- /dev/null +++ b/sfaira/models/celltype/base.py @@ -0,0 +1,22 @@ +import abc +try: + import tensorflow as tf +except ImportError: + tf = None +from sfaira.models.base import BasicModelKeras + + +class BasicModelEmbedding: + + @abc.abstractmethod + def predict(self, x, **kwargs): + pass + + +class BasicModelKerasCelltype(BasicModelKeras): + """ + This base class defines model attributes shared across all tf.keras cell type models. + """ + + def predict(self, x, **kwarg): + return self.training_model.predict(x) diff --git a/sfaira/models/celltype/marker.py b/sfaira/models/celltype/marker.py index 2e05e694a..1b3c342ad 100644 --- a/sfaira/models/celltype/marker.py +++ b/sfaira/models/celltype/marker.py @@ -6,7 +6,7 @@ from sfaira.versions.metadata import CelltypeUniverse from sfaira.versions.topologies import TopologyContainer -from sfaira.models.base import BasicModelKeras +from sfaira.models.celltype.base import BasicModelKerasCelltype from sfaira.models.pp_layer import PreprocInput @@ -40,7 +40,7 @@ def call(self, inputs): return tf.nn.sigmoid(x) -class CellTypeMarker(BasicModelKeras): +class CellTypeMarker(BasicModelKerasCelltype): """ Marker gene-based cell type classifier: Learns whether or not each gene exceeds requires threshold and learns cell type assignment as linear combination of these marker gene presence probabilities. 
diff --git a/sfaira/models/celltype/mlp.py b/sfaira/models/celltype/mlp.py index 92116da6c..f846c9131 100644 --- a/sfaira/models/celltype/mlp.py +++ b/sfaira/models/celltype/mlp.py @@ -7,11 +7,11 @@ from sfaira.versions.metadata import CelltypeUniverse from sfaira.versions.topologies import TopologyContainer -from sfaira.models.base import BasicModelKeras +from sfaira.models.celltype.base import BasicModelKerasCelltype from sfaira.models.pp_layer import PreprocInput -class CellTypeMlp(BasicModelKeras): +class CellTypeMlp(BasicModelKerasCelltype): """ Multi-layer perceptron to predict cell type. diff --git a/sfaira/models/embedding/__init__.py b/sfaira/models/embedding/__init__.py index f206e5be5..c70813e5a 100644 --- a/sfaira/models/embedding/__init__.py +++ b/sfaira/models/embedding/__init__.py @@ -1,3 +1,4 @@ +from sfaira.models.embedding.base import BasicModelKerasEmbedding from sfaira.models.embedding.ae import ModelKerasAe, ModelAeVersioned from sfaira.models.embedding.vae import ModelKerasVae, ModelVaeVersioned from sfaira.models.embedding.linear import ModelKerasLinear, ModelLinearVersioned diff --git a/sfaira/models/embedding/ae.py b/sfaira/models/embedding/ae.py index 08ad9a396..428132830 100644 --- a/sfaira/models/embedding/ae.py +++ b/sfaira/models/embedding/ae.py @@ -8,7 +8,7 @@ from sfaira.models.embedding.output_layers import NegBinOutput, NegBinSharedDispOutput, NegBinConstDispOutput, \ GaussianOutput, GaussianSharedStdOutput, GaussianConstStdOutput from sfaira.versions.topologies import TopologyContainer -from sfaira.models.base import BasicModelKeras +from sfaira.models.embedding.base import BasicModelKerasEmbedding from sfaira.models.pp_layer import PreprocInput @@ -113,7 +113,7 @@ def call(self, x, **kwargs): return x -class ModelKerasAe(BasicModelKeras): +class ModelKerasAe(BasicModelKerasEmbedding): """Combines the encoder and decoder into an end-to-end model for training.""" # Note: Original DCA implementation uses l1_l2 regularisation also on last layer (nb) - missing here # Note: Original DCA implementation uses softplus function instead of exponential as dispersion activation @@ -198,14 +198,6 @@ def __init__( name="autoencoder" ) - def predict_reconstructed(self, x): - return np.split(self.training_model.predict(x), indices_or_sections=2, axis=1)[0] - - def predict_embedding(self, x, variational=False): - if variational: - raise ValueError("Cannot predict variational embedding on AE model.topo") - return self.encoder_model.predict(x) - class ModelAeVersioned(ModelKerasAe): def __init__( diff --git a/sfaira/models/embedding/base.py b/sfaira/models/embedding/base.py new file mode 100644 index 000000000..6b0b96ca2 --- /dev/null +++ b/sfaira/models/embedding/base.py @@ -0,0 +1,32 @@ +import abc +import numpy as np +try: + import tensorflow as tf +except ImportError: + tf = None +from sfaira.models.base import BasicModelKeras + + +class BasicModelEmbedding: + + @abc.abstractmethod + def predict_reconstructed(self, x, **kwargs): + pass + + @abc.abstractmethod + def predict_embedding(self, x, **kwargs): + pass + + +class BasicModelKerasEmbedding(BasicModelKeras, BasicModelEmbedding): + """ + This base class defines model attributes shared across all tf.keras embedding models. 
+ """ + + encoder_model: tf.keras.Model + + def predict_reconstructed(self, x, **kwargs): + return np.split(self.training_model.predict(x), indices_or_sections=2, axis=1)[0] + + def predict_embedding(self, x, **kwargs): + return self.encoder_model.predict(x) diff --git a/sfaira/models/embedding/linear.py b/sfaira/models/embedding/linear.py index 3004006be..cea092bfe 100644 --- a/sfaira/models/embedding/linear.py +++ b/sfaira/models/embedding/linear.py @@ -8,7 +8,7 @@ from sfaira.models.embedding.output_layers import NegBinOutput, NegBinSharedDispOutput, NegBinConstDispOutput, \ GaussianOutput, GaussianSharedStdOutput, GaussianConstStdOutput from sfaira.versions.topologies import TopologyContainer -from sfaira.models.base import BasicModelKeras +from sfaira.models.embedding.base import BasicModelKerasEmbedding from sfaira.models.pp_layer import PreprocInput @@ -39,7 +39,7 @@ def call(self, inputs, **kwargs): return x -class ModelKerasLinear(BasicModelKeras): +class ModelKerasLinear(BasicModelKerasEmbedding): def __init__( self, @@ -95,12 +95,6 @@ def __init__( name="autoencoder" ) - def predict_reconstructed(self, x): - return np.split(self.training_model.predict(x), indices_or_sections=2, axis=1)[0] - - def predict_embedding(self, x, **kwargs): - return self.encoder_model.predict(x) - class ModelLinearVersioned(ModelKerasLinear): def __init__( diff --git a/sfaira/models/embedding/vae.py b/sfaira/models/embedding/vae.py index 47fafa498..59612627d 100644 --- a/sfaira/models/embedding/vae.py +++ b/sfaira/models/embedding/vae.py @@ -8,7 +8,7 @@ from sfaira.models.embedding.output_layers import NegBinOutput, NegBinSharedDispOutput, NegBinConstDispOutput, \ GaussianOutput, GaussianSharedStdOutput, GaussianConstStdOutput from sfaira.versions.topologies import TopologyContainer -from sfaira.models.base import BasicModelKeras +from sfaira.models.embedding.base import BasicModelKerasEmbedding from sfaira.models.pp_layer import PreprocInput @@ -140,7 +140,7 @@ def call(self, inputs, **kwargs): return x -class ModelKerasVae(BasicModelKeras): +class ModelKerasVae(BasicModelKerasEmbedding): def predict_reconstructed(self, x: np.ndarray): return np.split(self.training_model.predict(x)[0], indices_or_sections=2, axis=1)[0] diff --git a/sfaira/models/embedding/vaeiaf.py b/sfaira/models/embedding/vaeiaf.py index 4db1875a6..c63a72214 100644 --- a/sfaira/models/embedding/vaeiaf.py +++ b/sfaira/models/embedding/vaeiaf.py @@ -8,7 +8,7 @@ from sfaira.models.embedding.output_layers import NegBinOutput, NegBinSharedDispOutput, NegBinConstDispOutput, \ GaussianOutput, GaussianSharedStdOutput, GaussianConstStdOutput from sfaira.versions.topologies import TopologyContainer -from sfaira.models.base import BasicModelKeras +from sfaira.models.embedding.base import BasicModelKerasEmbedding from sfaira.models.pp_layer import PreprocInput from sfaira.models.made import MaskingDense @@ -221,7 +221,7 @@ def call(self, inputs, **kwargs): return x -class ModelKerasVaeIAF(BasicModelKeras): +class ModelKerasVaeIAF(BasicModelKerasEmbedding): def __init__( self, @@ -329,9 +329,6 @@ def __init__( name="autoencoder" ) - def predict_reconstructed(self, x): - return np.split(self.training_model.predict(x)[0], indices_or_sections=2, axis=1)[0] - def predict_embedding(self, x, variational=False, return_z0=False): if return_z0 and variational: z_t, z_t_mean, z_0 = self.encoder_model.predict(x) diff --git a/sfaira/models/embedding/vaevamp.py b/sfaira/models/embedding/vaevamp.py index 88062b1fc..52be9b557 100644 --- 
a/sfaira/models/embedding/vaevamp.py +++ b/sfaira/models/embedding/vaevamp.py @@ -8,7 +8,7 @@ from sfaira.models.embedding.output_layers import NegBinOutput, NegBinSharedDispOutput, NegBinConstDispOutput, \ GaussianOutput, GaussianSharedStdOutput, GaussianConstStdOutput from sfaira.versions.topologies import TopologyContainer -from sfaira.models.base import BasicModelKeras +from sfaira.models.embedding.base import BasicModelKerasEmbedding from sfaira.models.pp_layer import PreprocInput @@ -200,7 +200,7 @@ def call(self, inputs, **kwargs): return (p_z1_mean, p_z1_log_var), (p_z2_mean, p_z2_log_var), out -class ModelKerasVaeVamp(BasicModelKeras): +class ModelKerasVaeVamp(BasicModelKerasEmbedding): def predict_reconstructed(self, x: np.ndarray): return np.split(self.training_model.predict(x)[0], indices_or_sections=2, axis=1)[0] diff --git a/sfaira/train/summaries.py b/sfaira/train/summaries.py index 8d567c614..55c8744a7 100644 --- a/sfaira/train/summaries.py +++ b/sfaira/train/summaries.py @@ -911,7 +911,7 @@ def _fn(yhat, ytrue): store.subset(attr_key="id", values=[k for k in store.indices.keys() if 'cell_ontology_class' in store.adata_by_key[k].obs.columns]) store.subset(attr_key="cellontology_class", excluded_values=[ - store._adata_ids_sfaira.unknown_celltype_identifier, + store._adata_ids_sfaira.unknown_metadata_identifier, store._adata_ids_sfaira.not_a_cell_celltype_identifier, ]) cu = CelltypeUniverse( @@ -1076,7 +1076,7 @@ def plot_best_classwise_scatter( store.subset(attr_key="id", values=[k for k in store.indices.keys() if 'cell_ontology_id' in store.adata_by_key[k].obs.columns]) store.subset(attr_key="cellontology_class", excluded_values=[ - store._adata_ids_sfaira.unknown_celltype_identifier, + store._adata_ids_sfaira.unknown_metadata_identifier, store._adata_ids_sfaira.not_a_cell_celltype_identifier, ]) cu = CelltypeUniverse( @@ -1426,7 +1426,7 @@ def get_gradients_by_celltype( store.subset(attr_key="id", values=[k for k in store.indices.keys() if 'cell_ontology_id' in store.adata_by_key[k].obs.columns]) store.subset(attr_key="cellontology_class", excluded_values=[ - store._adata_ids_sfaira.unknown_celltype_identifier, + store._adata_ids_sfaira.unknown_metadata_identifier, store._adata_ids_sfaira.not_a_cell_celltype_identifier, ]) adatas = store.adata_sliced diff --git a/sfaira/train/train_model.py b/sfaira/train/train_model.py index 4ed5aede9..f0c791055 100644 --- a/sfaira/train/train_model.py +++ b/sfaira/train/train_model.py @@ -6,6 +6,7 @@ from typing import Union from sfaira.consts import AdataIdsSfaira +from sfaira.data.store.base import DistributedStoreBase from sfaira.data import DistributedStoreSingleFeatureSpace, Universe from sfaira.estimators import EstimatorKeras, EstimatorKerasCelltype, EstimatorKerasEmbedding from sfaira.ui import ModelZoo @@ -28,12 +29,14 @@ def __init__( self.data.obs = pd.read_csv(fn_backed_obs) elif isinstance(data, anndata.AnnData): self.data = data + elif isinstance(data, list) and isinstance(data[0], anndata.AnnData): + self.data = data elif isinstance(data, Universe): self.data = data.adata - elif isinstance(data, DistributedStoreSingleFeatureSpace): + elif isinstance(data, DistributedStoreBase): self.data = data else: - raise ValueError(f"did not recongize data of type {type(data)}") + raise ValueError(f"did not recognize data of type {type(data)}") self.zoo = ModelZoo() self._adata_ids = AdataIdsSfaira() @@ -241,6 +244,6 @@ def _save_specific(self, fn: str, **kwargs): with open(fn + "_topology.pickle", "wb") as f: 
pickle.dump(obj=self.topology_dict, file=f) - cell_counts = obs['cell_ontology_class'].value_counts().to_dict() + cell_counts = obs['cell_type'].value_counts().to_dict() with open(fn + '_celltypes_valuecounts_wholedata.pickle', 'wb') as f: pickle.dump(obj=[cell_counts], file=f) diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.yaml b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.yaml index 3c64351e8..6b8a32dfe 100644 --- a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.yaml +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock1/human_lung_2021_10xtechnology_mock1_001.yaml @@ -23,7 +23,7 @@ dataset_or_observation_wise: bio_sample_obs_key: cell_line: cell_line_obs_key: - development_stage: + development_stage: "50-year-old human stage" development_stage_obs_key: disease: "healthy" disease_obs_key: @@ -44,7 +44,7 @@ dataset_or_observation_wise: tech_sample: tech_sample_obs_key: observation_wise: - cell_types_original_obs_key: "free_annotation" + cell_type_obs_key: "free_annotation" feature_wise: gene_id_ensembl_var_key: "index" gene_id_symbols_var_key: diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.yaml b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.yaml index 436de0756..d55899535 100644 --- a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.yaml +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock2/mouse_pancreas_2021_10xtechnology_mock2_001.yaml @@ -3,7 +3,7 @@ dataset_structure: sample_fns: dataset_wise: author: - - "mock3" + - "mock2" default_embedding: doi_journal: "no_doi_mock2" doi_preprint: @@ -23,7 +23,7 @@ dataset_or_observation_wise: bio_sample_obs_key: cell_line: cell_line_obs_key: - development_stage: + development_stage: "2 weeks" development_stage_obs_key: disease: "healthy" disease_obs_key: @@ -44,7 +44,7 @@ dataset_or_observation_wise: tech_sample: tech_sample_obs_key: observation_wise: - cell_types_original_obs_key: "free_annotation" + cell_type_obs_key: "free_annotation" feature_wise: gene_id_ensembl_var_key: "index" gene_id_symbols_var_key: diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.yaml b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.yaml index e2f876bff..f44451498 100644 --- a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.yaml +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock3/human_lung_2021_10xtechnology_mock3_001.yaml @@ -3,7 +3,7 @@ dataset_structure: sample_fns: dataset_wise: author: - - "mock2" + - "mock3" default_embedding: doi_journal: "no_doi_mock3" doi_preprint: @@ -44,7 +44,7 @@ dataset_or_observation_wise: tech_sample: tech_sample_obs_key: observation_wise: - cell_types_original_obs_key: "free_annotation" + cell_type_obs_key: "free_annotation" feature_wise: gene_id_ensembl_var_key: "index" gene_id_symbols_var_key: diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/__init__.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/__init__.py new file mode 100644 index 000000000..b1d5b2c2b --- /dev/null +++ 
b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/__init__.py @@ -0,0 +1 @@ +FILE_PATH = __file__ diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/human_lung_2021_10xtechnology_mock4_001.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/human_lung_2021_10xtechnology_mock4_001.py new file mode 100644 index 000000000..00c808896 --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/human_lung_2021_10xtechnology_mock4_001.py @@ -0,0 +1,12 @@ +import anndata + +from sfaira.unit_tests.data_for_tests.loaders.consts import ASSEMBLY_HUMAN +from sfaira.unit_tests.data_for_tests.loaders.utils import _create_adata + + +def load(data_dir, sample_fn, **kwargs) -> anndata.AnnData: + ncells = 20 + ngenes = 60 + adata = _create_adata(celltypes=[], ncells=ncells, ngenes=ngenes, + assembly=ASSEMBLY_HUMAN) + return adata diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/human_lung_2021_10xtechnology_mock4_001.yaml b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/human_lung_2021_10xtechnology_mock4_001.yaml new file mode 100644 index 000000000..a06ea951e --- /dev/null +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/dno_doi_mock4/human_lung_2021_10xtechnology_mock4_001.yaml @@ -0,0 +1,52 @@ +dataset_structure: + dataset_index: 1 + sample_fns: +dataset_wise: + author: + - "mock4" + default_embedding: + doi_journal: "no_doi_mock4" + doi_preprint: + download_url_data: "" + download_url_meta: "" + normalization: "raw" + primary_data: + year: 2021 +dataset_or_observation_wise: + assay_sc: + assay_sc_obs_key: + assay_differentiation: + assay_differentiation_obs_key: + assay_type_differentiation: + assay_type_differentiation_obs_key: + bio_sample: + bio_sample_obs_key: + cell_line: + cell_line_obs_key: + development_stage: + development_stage_obs_key: + disease: + disease_obs_key: + ethnicity: + ethnicity_obs_key: + individual: + individual_obs_key: + organ: "lung" + organ_obs_key: + organism: "human" + organism_obs_key: + sample_source: "primary_tissue" + sample_source_obs_key: + sex: + sex_obs_key: + state_exact: + state_exact_obs_key: + tech_sample: + tech_sample_obs_key: +observation_wise: + cell_type_obs_key: +feature_wise: + gene_id_ensembl_var_key: "index" + gene_id_symbols_var_key: +meta: + version: "1.0" diff --git a/sfaira/unit_tests/data_for_tests/loaders/loaders/super_group.py b/sfaira/unit_tests/data_for_tests/loaders/loaders/super_group.py index 46b508203..c55eaeb68 100644 --- a/sfaira/unit_tests/data_for_tests/loaders/loaders/super_group.py +++ b/sfaira/unit_tests/data_for_tests/loaders/loaders/super_group.py @@ -51,7 +51,7 @@ def __init__(self): sample_fns=None, yaml_path=fn_yaml, ) - x.load_ontology_class_map(fn=os.path.join(path_module, file_module + ".tsv")) + x.read_ontology_class_map(fn=os.path.join(path_module, file_module + ".tsv")) datasets.append(x) else: warn(f"DatasetGroupDirectoryOriented was None for {f}") diff --git a/sfaira/unit_tests/data_for_tests/loaders/utils.py b/sfaira/unit_tests/data_for_tests/loaders/utils.py index a94ddf5d4..fcd1a8037 100644 --- a/sfaira/unit_tests/data_for_tests/loaders/utils.py +++ b/sfaira/unit_tests/data_for_tests/loaders/utils.py @@ -7,7 +7,7 @@ from sfaira.versions.genomes import GenomeContainer from sfaira.unit_tests.directories import DIR_DATA_LOADERS_CACHE, DIR_DATA_LOADERS_STORE_DAO, \ - DIR_DATA_LOADERS_STORE_H5AD + DIR_DATA_LOADERS_STORE_H5AD, save_delete from .consts import ASSEMBLY_HUMAN, ASSEMBLY_MOUSE from 
.loaders import DatasetSuperGroupMock @@ -23,9 +23,9 @@ def _create_adata(celltypes, ncells, ngenes, assembly) -> anndata.AnnData: genes = gc.ensembl[:ngenes] x = scipy.sparse.csc_matrix(np.random.randint(low=0, high=100, size=(ncells, ngenes))) var = pd.DataFrame(index=genes) - obs = pd.DataFrame({ - "free_annotation": [celltypes[i] for i in np.random.choice(a=[0, 1], size=ncells, replace=True)] - }, index=["cell_" + str(i) for i in range(ncells)]) + obs = pd.DataFrame({}, index=["cell_" + str(i) for i in range(ncells)]) + if len(celltypes) > 0: + obs["free_annotation"] = [celltypes[i] for i in np.random.choice(len(celltypes), size=ncells, replace=True)] adata = anndata.AnnData(X=x, obs=obs, var=var) return adata @@ -71,11 +71,15 @@ def prepare_store(store_format: str, rewrite: bool = False, rewrite_store: bool else: compression_kwargs = {} if store_format == "dao": - anticipated_fn = os.path.join(dir_store_formatted, k) + anticipated_fn = os.path.join(dir_store_formatted, ds.doi_cleaned_id) elif store_format == "h5ad": - anticipated_fn = os.path.join(dir_store_formatted, k + ".h5ad") + anticipated_fn = os.path.join(dir_store_formatted, ds.doi_cleaned_id + ".h5ad") else: assert False + if rewrite_store and os.path.exists(anticipated_fn): + # Can't write if h5ad already exists. + # Delete store to writing if forced. + save_delete(anticipated_fn) # Only rewrite if necessary if rewrite_store or not os.path.exists(anticipated_fn): ds = _load_script(dsg=ds, rewrite=rewrite, match_to_reference=MATCH_TO_REFERENCE) diff --git a/sfaira/unit_tests/directories.py b/sfaira/unit_tests/directories.py index f2c457c0f..b1a202ec5 100644 --- a/sfaira/unit_tests/directories.py +++ b/sfaira/unit_tests/directories.py @@ -3,6 +3,7 @@ """ import os +import shutil DIR_TEMP = os.path.join(os.path.dirname(__file__), "temp") @@ -12,3 +13,12 @@ DIR_DATA_LOADERS_STORE_H5AD = os.path.join(_DIR_DATA_LOADERS, "store_h5ad") _DIR_DATA_DATABASES = os.path.join(DIR_TEMP, "databases") DIR_DATA_DATABASES_CACHE = os.path.join(_DIR_DATA_DATABASES, "cache") +DIR_DATABASE_STORE_DAO = os.path.join(_DIR_DATA_DATABASES, "store_dao") + + +def save_delete(fn): + assert str(fn).startswith(DIR_TEMP), f"tried to delete outside of temp directory {fn}" + if os.path.isdir(fn): + shutil.rmtree(fn) + else: + os.remove(fn) diff --git a/sfaira/unit_tests/tests_by_submodule/data/databases/__init__.py b/sfaira/unit_tests/tests_by_submodule/data/databases/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sfaira/unit_tests/tests_by_submodule/data/databases/test_database_intput.py b/sfaira/unit_tests/tests_by_submodule/data/databases/test_database_intput.py new file mode 100644 index 000000000..c9d44c95f --- /dev/null +++ b/sfaira/unit_tests/tests_by_submodule/data/databases/test_database_intput.py @@ -0,0 +1,65 @@ +import os +import pytest +from typing import List + +from sfaira.consts import AdataIdsSfaira +from sfaira.data.store.io_dao import read_dao +from sfaira.unit_tests.data_for_tests.databases.utils import prepare_dsg_database +from sfaira.unit_tests.data_for_tests.databases.consts import CELLXGENE_DATASET_ID +from sfaira.unit_tests.data_for_tests.loaders import ASSEMBLY_HUMAN, ASSEMBLY_MOUSE +from sfaira.unit_tests.directories import DIR_DATABASE_STORE_DAO + + +@pytest.mark.parametrize("database", ["cellxgene", ]) +@pytest.mark.parametrize("subset_args", [["id", CELLXGENE_DATASET_ID], ]) +@pytest.mark.parametrize("match_to_reference", [{"human": ASSEMBLY_HUMAN, "mouse": ASSEMBLY_MOUSE}, ]) 
+@pytest.mark.parametrize("subset_genes_to_type", [None, "protein_coding", ]) +def test_streamline_features(database: str, subset_args: List[str], match_to_reference: dict, + subset_genes_to_type: str): + dsg = prepare_dsg_database(database=database) + dsg.subset(key=subset_args[0], values=subset_args[1]) + dsg.load() + dsg.streamline_features(match_to_reference=match_to_reference, subset_genes_to_type=subset_genes_to_type) + + +@pytest.mark.parametrize("database", ["cellxgene", ]) +@pytest.mark.parametrize("subset_args", [["id", CELLXGENE_DATASET_ID], ]) +@pytest.mark.parametrize("format", ["sfaira", ]) +def test_streamline_metadata(database: str, subset_args: List[str], format: str): + dsg = prepare_dsg_database(database=database) + dsg.subset(key=subset_args[0], values=subset_args[1]) + dsg.load() + dsg.streamline_features(match_to_reference={"human": ASSEMBLY_HUMAN, "mouse": ASSEMBLY_MOUSE}, + subset_genes_to_type="protein_coding") + dsg.streamline_metadata(schema=format) + adata = dsg.datasets[subset_args[1]].adata + ids = AdataIdsSfaira() + assert "CL:0000128" in adata.obs[ids.cell_type + ids.onto_id_suffix].values + assert "oligodendrocyte" in adata.obs[ids.cell_type].values + assert "HsapDv:0000087" in adata.obs[ids.development_stage + ids.onto_id_suffix].values + assert "human adult stage" in adata.obs[ids.development_stage].values + assert "UBERON:0000956" in adata.obs[ids.organ + ids.onto_id_suffix].values + assert "cerebral cortex" in adata.obs[ids.organ].values + + +@pytest.mark.parametrize("store", ["dao", ]) +@pytest.mark.parametrize("database", ["cellxgene", ]) +@pytest.mark.parametrize("subset_args", [["id", CELLXGENE_DATASET_ID], ]) +def test_output_to_store(store: str, database: str, subset_args: List[str]): + dsg = prepare_dsg_database(database=database) + dsg.subset(key=subset_args[0], values=subset_args[1]) + dsg.load() + dsg.streamline_features(match_to_reference={"human": ASSEMBLY_HUMAN, "mouse": ASSEMBLY_MOUSE}, + subset_genes_to_type="protein_coding") + dsg.streamline_metadata(schema="sfaira", clean_obs=True, clean_uns=True, clean_var=True, clean_obs_names=True, + keep_id_obs=True, keep_orginal_obs=False, keep_symbol_obs=True) + dsg.write_distributed_store(dir_cache=DIR_DATABASE_STORE_DAO, store_format=store, dense=True) + fn_store = os.path.join(DIR_DATABASE_STORE_DAO, subset_args[1]) + adata = read_dao(store=fn_store) + ids = AdataIdsSfaira() + assert "CL:0000128" in adata.obs[ids.cell_type + ids.onto_id_suffix].values + assert "oligodendrocyte" in adata.obs[ids.cell_type].values + assert "HsapDv:0000087" in adata.obs[ids.development_stage + ids.onto_id_suffix].values + assert "human adult stage" in adata.obs[ids.development_stage].values + assert "UBERON:0000956" in adata.obs[ids.organ + ids.onto_id_suffix].values + assert "cerebral cortex" in adata.obs[ids.organ].values diff --git a/sfaira/unit_tests/tests_by_submodule/data/databases/test_databases_basic.py b/sfaira/unit_tests/tests_by_submodule/data/databases/test_databases_basic.py new file mode 100644 index 000000000..503577c6d --- /dev/null +++ b/sfaira/unit_tests/tests_by_submodule/data/databases/test_databases_basic.py @@ -0,0 +1,35 @@ +import os +import pytest +import shutil +from typing import List + +from sfaira.unit_tests.directories import DIR_DATA_DATABASES_CACHE +from sfaira.unit_tests.data_for_tests.databases.utils import prepare_dsg_database +from sfaira.unit_tests.data_for_tests.databases.consts import CELLXGENE_DATASET_ID + + +# Execute this one first so that data sets are only downloaded 
once. Named test_a for this reason. +@pytest.mark.parametrize("database", ["cellxgene", ]) +@pytest.mark.parametrize("subset_args", [None, ["id", CELLXGENE_DATASET_ID], ]) +def test_a_dsgs_download(database: str, subset_args: List[str]): + """ + Tests if downloading of data base entries works. + + Warning, deletes entire database unit test cache. + """ + if os.path.exists(DIR_DATA_DATABASES_CACHE): + shutil.rmtree(DIR_DATA_DATABASES_CACHE) + dsg = prepare_dsg_database(database=database, download=False) + if subset_args is not None: + dsg.subset(key=subset_args[0], values=subset_args[1]) + dsg.download() + + +@pytest.mark.parametrize("database", ["cellxgene", ]) +@pytest.mark.parametrize("subset_args", [["id", CELLXGENE_DATASET_ID], ["organism", "human"], ]) +def test_dsgs_subset(database: str, subset_args: List[str]): + """ + Tests if subsetting results only in datasets of the desired characteristics. + """ + dsg = prepare_dsg_database(database=database) + dsg.subset(key=subset_args[0], values=subset_args[1]) diff --git a/sfaira/unit_tests/tests_by_submodule/data/dataset/__init__.py b/sfaira/unit_tests/tests_by_submodule/data/dataset/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sfaira/unit_tests/tests_by_submodule/data/test_dataset.py b/sfaira/unit_tests/tests_by_submodule/data/dataset/test_dataset.py similarity index 74% rename from sfaira/unit_tests/tests_by_submodule/data/test_dataset.py rename to sfaira/unit_tests/tests_by_submodule/data/dataset/test_dataset.py index c7e063e0c..0a2fd2ac6 100644 --- a/sfaira/unit_tests/tests_by_submodule/data/test_dataset.py +++ b/sfaira/unit_tests/tests_by_submodule/data/dataset/test_dataset.py @@ -13,6 +13,21 @@ def test_dsgs_instantiate(): _ = Universe(data_path=DIR_DATA_LOADERS_CACHE, meta_path=DIR_DATA_LOADERS_CACHE, cache_path=DIR_DATA_LOADERS_CACHE) +def test_dsgs_crossref(): + """ + Tests if crossref attributes can be retrieved for all data loader entries with DOI journal defined. 
+ Attributes tested: + - title + """ + universe = Universe(data_path=DIR_DATA_LOADERS_CACHE, meta_path=DIR_DATA_LOADERS_CACHE, + cache_path=DIR_DATA_LOADERS_CACHE) + for k, v in universe.datasets.items(): + title = v.title + if title is None: + if v.doi_journal is not None and "no_doi" not in v.doi_journal: + raise ValueError(f"did not retrieve title for data set {k} with DOI: {v.doi_journal}.") + + @pytest.mark.parametrize("organ", ["intestine", "ileum"]) def test_dsgs_subset_dataset_wise(organ: str): """ @@ -70,26 +85,8 @@ def test_dsgs_subset_cell_wise(organ: str, celltype: str): for k, v in x.datasets.items(): assert v.organism == "mouse", v.id assert v.ontology_container_sfaira.organ.is_a(query=v.organ, reference=organ), v.organ - for y in np.unique(v.adata.obs[v._adata_ids.cellontology_class].values): - assert v.ontology_container_sfaira.cellontology_class.is_a(query=y, reference=celltype), y - - -@pytest.mark.parametrize("out_format", ["sfaira", "cellxgene"]) -@pytest.mark.parametrize("uns_to_obs", [True, False]) -@pytest.mark.parametrize("clean_obs", [True, False]) -@pytest.mark.parametrize("clean_var", [True, False]) -@pytest.mark.parametrize("clean_uns", [True, False]) -@pytest.mark.parametrize("clean_obs_names", [True, False]) -def test_dsgs_streamline_metadata(out_format: str, uns_to_obs: bool, clean_obs: bool, clean_var: bool, clean_uns: bool, - clean_obs_names: bool): - ds = prepare_dsg(load=False) - ds.subset(key="organism", values=["mouse"]) - ds.subset(key="organ", values=["lung"]) - ds.load() - ds.streamline_features(remove_gene_version=False, match_to_reference=ASSEMBLY_MOUSE, - subset_genes_to_type=None) - ds.streamline_metadata(schema=out_format, clean_obs=clean_obs, clean_var=clean_var, - clean_uns=clean_uns, clean_obs_names=clean_obs_names) + for y in np.unique(v.adata.obs[v._adata_ids.cell_type].values): + assert v.ontology_container_sfaira.cell_type.is_a(query=y, reference=celltype), y @pytest.mark.parametrize("match_to_reference", ["Mus_musculus.GRCm38.102", {"mouse": ASSEMBLY_MOUSE}]) diff --git a/sfaira/unit_tests/tests_by_submodule/data/dataset/test_meta_data_streamlining.py b/sfaira/unit_tests/tests_by_submodule/data/dataset/test_meta_data_streamlining.py new file mode 100644 index 000000000..2537de03f --- /dev/null +++ b/sfaira/unit_tests/tests_by_submodule/data/dataset/test_meta_data_streamlining.py @@ -0,0 +1,55 @@ +from cellxgene_schema.validate import validate_adata +import pytest + +from sfaira.unit_tests.data_for_tests.loaders import ASSEMBLY_HUMAN, ASSEMBLY_MOUSE, prepare_dsg + + +@pytest.mark.parametrize("out_format", ["sfaira", "cellxgene"]) +@pytest.mark.parametrize("clean_obs", [True, False]) +@pytest.mark.parametrize("clean_var", [True, False]) +@pytest.mark.parametrize("clean_uns", [True, False]) +@pytest.mark.parametrize("clean_obs_names", [True, False]) +@pytest.mark.parametrize("keep_id_obs", [True]) +@pytest.mark.parametrize("keep_orginal_obs", [False]) +@pytest.mark.parametrize("keep_symbol_obs", [True]) +def test_dsgs_streamline_metadata(out_format: str, clean_obs: bool, clean_var: bool, clean_uns: bool, + clean_obs_names: bool, keep_id_obs: bool, keep_orginal_obs: bool, + keep_symbol_obs: bool): + ds = prepare_dsg(load=False) + ds.subset(key="organism", values=["human"]) + ds.subset(key="organ", values=["lung"]) + if out_format == "cellxgene": + # Other data data sets do not have complete enough annotation + ds.subset(key="doi_journal", values=["no_doi_mock1", "no_doi_mock3"]) + ds.load() + ds.streamline_features(remove_gene_version=False, 
match_to_reference=ASSEMBLY_MOUSE, + subset_genes_to_type=None) + ds.streamline_metadata(schema=out_format, clean_obs=clean_obs, clean_var=clean_var, + clean_uns=clean_uns, clean_obs_names=clean_obs_names, + keep_id_obs=keep_id_obs, keep_orginal_obs=keep_orginal_obs, keep_symbol_obs=keep_symbol_obs) + + +@pytest.mark.parametrize("schema_version", ["1_1_0"]) +@pytest.mark.parametrize("organism", ["human", "mouse"]) +def test_cellxgene_export(schema_version: str, organism: str): + """ + + This test can be extended by future versions. + """ + ds = prepare_dsg(load=False) + if organism == "human": + ds.subset(key="doi_journal", values=["no_doi_mock1"]) + else: + ds.subset(key="doi_journal", values=["no_doi_mock2"]) + ds.load() + ds.streamline_features(remove_gene_version=False, + match_to_reference={"human": ASSEMBLY_HUMAN, "mouse": ASSEMBLY_MOUSE}, + subset_genes_to_type=None) + ds.streamline_metadata(schema="cellxgene:" + schema_version, clean_obs=False, clean_var=True, + clean_uns=True, clean_obs_names=False, + keep_id_obs=True, keep_orginal_obs=False, keep_symbol_obs=True) + counter = 0 + for ds in ds.datasets.values(): + validate_adata(adata=ds.adata, shallow=False) + counter += 1 + assert counter > 0, "no data sets to test" diff --git a/sfaira/unit_tests/tests_by_submodule/data/test_databases.py b/sfaira/unit_tests/tests_by_submodule/data/test_databases.py deleted file mode 100644 index e4cbd32a1..000000000 --- a/sfaira/unit_tests/tests_by_submodule/data/test_databases.py +++ /dev/null @@ -1,61 +0,0 @@ -import os -import pytest -import shutil -from typing import List - -from sfaira.unit_tests.directories import DIR_DATA_DATABASES_CACHE -from sfaira.unit_tests.data_for_tests.databases.utils import prepare_dsg_database -from sfaira.unit_tests.data_for_tests.databases.consts import CELLXGENE_DATASET_ID -from sfaira.unit_tests.data_for_tests.loaders import ASSEMBLY_HUMAN, ASSEMBLY_MOUSE - - -# Execute this one first so that data sets are only downloaded once. Named test_a for this reason. -@pytest.mark.parametrize("database", ["cellxgene", ]) -@pytest.mark.parametrize("subset_args", [None, ["id", CELLXGENE_DATASET_ID], ]) -def test_a_dsgs_download(database: str, subset_args: List[str]): - """ - Tests if downloading of data base entries works. - - Warning, deletes entire database unit test cache. - """ - if os.path.exists(DIR_DATA_DATABASES_CACHE): - shutil.rmtree(DIR_DATA_DATABASES_CACHE) - dsg = prepare_dsg_database(database=database, download=False) - if subset_args is not None: - dsg.subset(key=subset_args[0], values=subset_args[1]) - dsg.download() - - -@pytest.mark.parametrize("database", ["cellxgene", ]) -@pytest.mark.parametrize("subset_args", [["id", CELLXGENE_DATASET_ID], ["organism", "human"], ]) -def test_dsgs_subset(database: str, subset_args: List[str]): - """ - Tests if subsetting results only in datasets of the desired characteristics. 
- """ - dsg = prepare_dsg_database(database=database) - dsg.subset(key=subset_args[0], values=subset_args[1]) - - -@pytest.mark.parametrize("database", ["cellxgene", ]) -@pytest.mark.parametrize("subset_args", [["id", CELLXGENE_DATASET_ID], ]) -@pytest.mark.parametrize("match_to_reference", [None, {"human": ASSEMBLY_HUMAN, "mouse": ASSEMBLY_MOUSE}, ]) -def test_dsgs_adata(database: str, subset_args: List[str], match_to_reference: dict): - dsg = prepare_dsg_database(database=database) - dsg.subset(key=subset_args[0], values=subset_args[1]) - dsg.load() - if match_to_reference is not None: - dsg.streamline_features(remove_gene_version=True, match_to_reference=match_to_reference) - dsg.streamline_metadata(schema="sfaira", clean_obs=True, clean_var=True, clean_uns=True, clean_obs_names=True) - _ = dsg.adata - - -@pytest.mark.parametrize("database", ["cellxgene", ]) -@pytest.mark.parametrize("subset_args", [["id", CELLXGENE_DATASET_ID], ]) -@pytest.mark.parametrize("match_to_reference", [{"human": ASSEMBLY_HUMAN, "mouse": ASSEMBLY_MOUSE}, ]) -@pytest.mark.parametrize("subset_genes_to_type", [None, "protein_coding", ]) -def test_dsgs_streamline_features(database: str, subset_args: List[str], match_to_reference: dict, - subset_genes_to_type: str): - dsg = prepare_dsg_database(database=database) - dsg.subset(key=subset_args[0], values=subset_args[1]) - dsg.load() - dsg.streamline_features(match_to_reference=match_to_reference, subset_genes_to_type=subset_genes_to_type) diff --git a/sfaira/unit_tests/tests_by_submodule/data/test_store.py b/sfaira/unit_tests/tests_by_submodule/data/test_store.py index a095669ca..1a380f9a7 100644 --- a/sfaira/unit_tests/tests_by_submodule/data/test_store.py +++ b/sfaira/unit_tests/tests_by_submodule/data/test_store.py @@ -13,21 +13,63 @@ from sfaira.unit_tests.data_for_tests.loaders import ASSEMBLY_MOUSE, prepare_dsg, prepare_store +def _get_single_store(store_format: str): + store_path = prepare_store(store_format=store_format) + stores = load_store(cache_path=store_path, store_format=store_format) + stores.subset(attr_key="organism", values=["mouse"]) + store = stores.stores["mouse"] + return store + + @pytest.mark.parametrize("store_format", ["h5ad", "dao"]) def test_fatal(store_format: str): """ Test if basic methods abort. """ store_path = prepare_store(store_format=store_format) - store = load_store(cache_path=store_path, store_format=store_format) - store.subset(attr_key="organism", values=["mouse"]) - _ = store.n_obs - _ = store.n_vars - _ = store.var_names - _ = store.shape - _ = store.obs - _ = store.stores["mouse"].indices - _ = store.genome_containers + stores = load_store(cache_path=store_path, store_format=store_format) + stores.subset(attr_key="organism", values=["mouse"]) + store = stores.stores["mouse"] + # Test both single and multi-store: + for x in [store, stores]: + _ = x.n_obs + _ = x.n_vars + _ = x.var_names + _ = x.shape + _ = x.obs + _ = x.indices + _ = x.genome_container + + +@pytest.mark.parametrize("store_format", ["h5ad", "dao"]) +@pytest.mark.parametrize("as_sparse", [True, False]) +def test_x_slice(store_format: str, as_sparse: bool): + """ + Test if basic methods abort. 
+ """ + store = _get_single_store(store_format=store_format) + data = store.X_slice(idx=np.arange(0, 5), as_sparse=as_sparse) + assert data.shape[0] == 5 + if as_sparse: + assert isinstance(data, scipy.sparse.csr_matrix) + else: + assert isinstance(data, np.ndarray) + + +@pytest.mark.parametrize("store_format", ["h5ad", "dao"]) +@pytest.mark.parametrize("as_sparse", [True, False]) +def test_adata_slice(store_format: str, as_sparse: bool): + """ + Test if basic methods abort. + """ + store = _get_single_store(store_format=store_format) + data = store.adata_slice(idx=np.arange(0, 5), as_sparse=as_sparse) + assert data.shape[0] == 5 + assert isinstance(data, anndata.AnnData) + if as_sparse: + assert isinstance(data.X, scipy.sparse.csr_matrix) + else: + assert isinstance(data.X, np.ndarray) @pytest.mark.parametrize("store_format", ["h5ad", "dao"]) @@ -37,7 +79,7 @@ def test_data(store_format: str): """ # Run standard streamlining workflow on dsg and compare to object relayed via store. # Prepare dsg. - dsg = prepare_dsg(rewrite=False, load=True) + dsg = prepare_dsg(load=True) # Prepare store. # Rewriting store to avoid mismatch of randomly generated data in cache and store. store_path = prepare_store(store_format=store_format, rewrite=False, rewrite_store=True) @@ -113,8 +155,8 @@ def test_config(store_format: str): @pytest.mark.parametrize("store_format", ["h5ad", "dao"]) @pytest.mark.parametrize("idx", [np.arange(1, 10), np.concatenate([np.arange(30, 50), np.array([1, 4, 98])])]) -@pytest.mark.parametrize("batch_size", [1, 7]) -@pytest.mark.parametrize("obs_keys", [["cell_ontology_class"]]) +@pytest.mark.parametrize("batch_size", [1, ]) +@pytest.mark.parametrize("obs_keys", [["cell_type"]]) @pytest.mark.parametrize("randomized_batch_access", [True, False]) def test_generator_shapes(store_format: str, idx, batch_size: int, obs_keys: List[str], randomized_batch_access: bool): """ @@ -125,13 +167,14 @@ def test_generator_shapes(store_format: str, idx, batch_size: int, obs_keys: Lis store.subset(attr_key="organism", values=["mouse"]) gc = GenomeContainer(assembly=ASSEMBLY_MOUSE) gc.subset(**{"biotype": "protein_coding"}) - store.genome_containers = gc - g, _ = store.generator( + store.genome_container = gc + g = store.generator( idx={"mouse": idx}, batch_size=batch_size, obs_keys=obs_keys, randomized_batch_access=randomized_batch_access, ) + g = g.iterator nobs = len(idx) if idx is not None else store.n_obs batch_sizes = [] x = None @@ -140,6 +183,10 @@ def test_generator_shapes(store_format: str, idx, batch_size: int, obs_keys: Lis for i, z in enumerate(g()): counter += 1 x_i, obs_i = z + if len(x_i.shape) == 1: + # x is flattened if batch size is 1: + assert batch_size == 1 + x_i = np.expand_dims(x_i, axis=0) assert x_i.shape[0] == obs_i.shape[0] if i == 0: x = x_i diff --git a/sfaira/unit_tests/tests_by_submodule/estimators/__init__.py b/sfaira/unit_tests/tests_by_submodule/estimators/__init__.py index 680468bc4..f9faa4fa3 100644 --- a/sfaira/unit_tests/tests_by_submodule/estimators/__init__.py +++ b/sfaira/unit_tests/tests_by_submodule/estimators/__init__.py @@ -1 +1 @@ -from .test_estimator import TARGETS, TestHelperEstimatorBase +from .test_estimator import TARGETS, HelperEstimatorBase diff --git a/sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py b/sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py index 1aba8f02d..3dce30567 100644 --- a/sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py +++ 
b/sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py @@ -7,13 +7,13 @@ from typing import Union from sfaira.consts import AdataIdsSfaira, CACHE_DIR -from sfaira.data import DistributedStoreSingleFeatureSpace, load_store +from sfaira.data import DistributedStoreSingleFeatureSpace, DistributedStoreMultipleFeatureSpaceBase, load_store from sfaira.estimators import EstimatorKeras, EstimatorKerasCelltype, EstimatorKerasEmbedding from sfaira.versions.genomes.genomes import CustomFeatureContainer from sfaira.versions.metadata import OntologyOboCustom from sfaira.versions.topologies import TopologyContainer -from sfaira.unit_tests.data_for_tests.loaders.consts import CELLTYPES +from sfaira.unit_tests.data_for_tests.loaders.consts import CELLTYPES, CL_VERSION from sfaira.unit_tests.data_for_tests.loaders.utils import prepare_dsg, prepare_store from sfaira.unit_tests.directories import DIR_TEMP @@ -55,7 +55,7 @@ "genes": None, }, "output": { - "cl": "v2021-02-01", + "cl": CL_VERSION.split("_")[0], "targets": TARGET_UNIVERSE }, "hyper_parameters": { @@ -65,24 +65,26 @@ } -class TestHelperEstimatorBase: +class HelperEstimatorBase: adata_ids: AdataIdsSfaira - data: Union[anndata.AnnData, DistributedStoreSingleFeatureSpace] + data: Union[anndata.AnnData, DistributedStoreSingleFeatureSpace, DistributedStoreMultipleFeatureSpaceBase] tc: TopologyContainer def load_adata(self, organism="human", organ=None): - dsg = prepare_dsg() + dsg = prepare_dsg(load=True) + dsg.subset(key="doi_journal", values=["no_doi_mock1", "no_doi_mock2", "no_doi_mock3"]) if organism is not None: dsg.subset(key="organism", values=organism) if organ is not None: dsg.subset(key="organ", values=organ) self.adata_ids = dsg.dataset_groups[0]._adata_ids - self.data = dsg.adata + self.data = dsg.adata_ls def load_store(self, organism="human", organ=None): store_path = prepare_store(store_format="dao") store = load_store(cache_path=store_path, store_format="dao") + store.subset(attr_key="doi_journal", values=["no_doi_mock1", "no_doi_mock2", "no_doi_mock3"]) if organism is not None: store.subset(attr_key="organism", values=organism) if organ is not None: @@ -90,8 +92,17 @@ def load_store(self, organism="human", organ=None): self.adata_ids = store._adata_ids_sfaira self.data = store.stores[organism] + def load_multistore(self): + store_path = prepare_store(store_format="dao") + store = load_store(cache_path=store_path, store_format="dao") + store.subset(attr_key="doi_journal", values=["no_doi_mock1", "no_doi_mock2", "no_doi_mock3"]) + self.adata_ids = store._adata_ids_sfaira + assert "mouse" in store.stores.keys(), store.stores.keys() + assert "human" in store.stores.keys(), store.stores.keys() + self.data = store + -class TestHelperEstimatorKeras(TestHelperEstimatorBase): +class HelperEstimatorKeras(HelperEstimatorBase): data: Union[anndata.AnnData, DistributedStoreSingleFeatureSpace] estimator: Union[EstimatorKeras] @@ -132,7 +143,7 @@ def estimator_train(self, test_split, randomized_batch_access): ) @abc.abstractmethod - def basic_estimator_test(self, test_split): + def basic_estimator_test(self): pass def load_estimator(self, model_type, data_type, feature_space, test_split, organism="human"): @@ -148,10 +159,10 @@ def fatal_estimator_test(self, model_type, data_type, test_split=0.1, feature_sp self.load_estimator(model_type=model_type, data_type=data_type, feature_space=feature_space, test_split=test_split) self.estimator_train(test_split=test_split, randomized_batch_access=False) - 
self.basic_estimator_test(test_split=test_split) + self.basic_estimator_test() -class HelperEstimatorKerasEmbedding(TestHelperEstimatorKeras): +class HelperEstimatorKerasEmbedding(HelperEstimatorKeras): estimator: EstimatorKerasEmbedding model_type: str @@ -187,7 +198,7 @@ def init_estimator(self, test_split): self.estimator.init_model() self.estimator.split_train_val_test(test_split=test_split, val_split=0.1) - def basic_estimator_test(self, test_split=0.1): + def basic_estimator_test(self): _ = self.estimator.evaluate() prediction_output = self.estimator.predict() prediction_embed = self.estimator.predict_embedding() @@ -206,7 +217,7 @@ def basic_estimator_test(self, test_split=0.1): assert np.allclose(prediction_embed, new_prediction_embed, rtol=1e-6, atol=1e-6) -class TestHelperEstimatorKerasCelltype(TestHelperEstimatorKeras): +class TestHelperEstimatorKerasCelltype(HelperEstimatorKeras): estimator: EstimatorKerasCelltype nleaves: int @@ -233,18 +244,18 @@ def init_estimator(self, test_split): model_dir=DIR_TEMP, cache_path=DIR_TEMP, model_id="testid", - model_topology=tc + model_topology=tc, ) leaves = self.estimator.celltype_universe.onto_cl.get_effective_leaves( - x=[x for x in self.data.obs[self.adata_ids.cellontology_class].values - if x != self.adata_ids.unknown_celltype_identifier] + x=[x for x in self.estimator.data.obs[self.adata_ids.cell_type].values + if x != self.adata_ids.unknown_metadata_identifier] ) self.nleaves = len(leaves) self.estimator.celltype_universe.onto_cl.leaves = leaves self.estimator.init_model() self.estimator.split_train_val_test(test_split=test_split, val_split=0.1) - def basic_estimator_test(self, test_split=0.1): + def basic_estimator_test(self): _ = self.estimator.evaluate() prediction_output = self.estimator.predict() assert prediction_output.shape[1] == self.nleaves, prediction_output.shape @@ -268,34 +279,38 @@ def init_obo_custom(self) -> OntologyOboCustom: return OntologyOboCustom(obo=os.path.join(os.path.dirname(__file__), "custom.obo")) def init_genome_custom(self, n_features) -> CustomFeatureContainer: - return CustomFeatureContainer(genome_tab=pd.DataFrame({ - "gene_name": ["dim_" + str(i) for i in range(n_features)], - "gene_id": ["dim_" + str(i) for i in range(n_features)], - "gene_biotype": ["embedding" for _ in range(n_features)], - })) + return CustomFeatureContainer( + genome_tab=pd.DataFrame({ + "gene_name": ["dim_" + str(i) for i in range(n_features)], + "gene_id": ["dim_" + str(i) for i in range(n_features)], + "gene_biotype": ["embedding" for _ in range(n_features)], + }), + organism="homo_sapiens", + ) def load_adata(self, organism="human", organ=None): - dsg = prepare_dsg(load=False) + dsg = prepare_dsg(load=True) + dsg.subset(key="doi_journal", values=["no_doi_mock1", "no_doi_mock3", "no_doi_mock3"]) if organism is not None: dsg.subset(key="organism", values=organism) if organ is not None: dsg.subset(key="organ", values=organ) self.adata_ids = dsg.dataset_groups[0]._adata_ids # Use mock data loading to generate base line object: - dsg.load() - self.data = dsg.datasets[list(dsg.datasets.keys())[0]].adata + self.data = dsg.adata # - Subset to target feature space size: self.data = self.data[:, :self.tc.gc.n_var].copy() # - Add in custom cell types: - self.data.obs[self.adata_ids.cellontology_class] = [ + self.data.obs[self.adata_ids.cell_type] = [ self.custom_types[np.random.randint(0, len(self.custom_types))] for _ in range(self.data.n_obs) ] - self.data.obs[self.adata_ids.cellontology_id] = 
self.data.obs[self.adata_ids.cellontology_class] + self.data.obs[self.adata_ids.cell_type + self.adata_ids.onto_id_suffix] = \ + self.data.obs[self.adata_ids.cell_type] # - Add in custom features: self.data.var_names = ["dim_" + str(i) for i in range(self.data.n_vars)] - self.data.var[self.adata_ids.gene_id_ensembl] = ["dim_" + str(i) for i in range(self.data.n_vars)] - self.data.var[self.adata_ids.gene_id_symbols] = ["dim_" + str(i) for i in range(self.data.n_vars)] + self.data.var[self.adata_ids.feature_id] = ["dim_" + str(i) for i in range(self.data.n_vars)] + self.data.var[self.adata_ids.feature_symbol] = ["dim_" + str(i) for i in range(self.data.n_vars)] def init_topology_custom(self, model_type: str, n_features): topology = TOPOLOGY_CELLTYPE_MODEL.copy() @@ -323,10 +338,11 @@ def fatal_estimator_test_custom(self): model_id="testid", model_topology=self.tc, celltype_ontology=obo, + remove_unlabeled_cells=False, # TODO this should not be necessary but all cells are filtered otherwise ) self.estimator.init_model() self.estimator_train(test_split=0.1, randomized_batch_access=False) - self.basic_estimator_test(test_split=0.1) + self.basic_estimator_test() # Test embedding models: @@ -388,29 +404,26 @@ def test_dataset_size(batch_size: int, randomized_batch_access: bool): test_estim.load_estimator(model_type="linear", data_type="store", feature_space="reduced", test_split=0.2, organism="human") idx_train = test_estim.estimator.idx_train - shuffle_buffer_size = None if randomized_batch_access else 2 - ds_train = test_estim.estimator._get_dataset(idx=idx_train, batch_size=batch_size, mode='eval', - shuffle_buffer_size=shuffle_buffer_size, - retrieval_batch_size=retrieval_batch_size, - randomized_batch_access=randomized_batch_access) + ds_train = test_estim.estimator.get_one_time_tf_dataset(idx=idx_train, batch_size=batch_size, mode='eval') x_train_shape = 0 for x, _ in ds_train.as_numpy_iterator(): x_train_shape += x[0].shape[0] # Define raw store generator on train data to compare and check that it has the same size as tf generator exposed # by estimator: - g_train, _ = test_estim.estimator.data.generator(idx=idx_train, batch_size=retrieval_batch_size, - randomized_batch_access=randomized_batch_access) + g_train = test_estim.estimator.data.generator(idx=idx_train, retrieval_batch_size=retrieval_batch_size, + randomized_batch_access=randomized_batch_access) x_train2_shape = 0 - for x, _ in g_train(): + for x, _ in g_train.iterator(): + if len(x.shape) == 1: + x = np.expand_dims(x, axis=0) x_train2_shape += x.shape[0] assert x_train_shape == x_train2_shape assert x_train_shape == len(idx_train) @pytest.mark.parametrize("data_type", ["adata", "store"]) -@pytest.mark.parametrize("randomized_batch_access", [False, True]) @pytest.mark.parametrize("test_split", [0.3, {"id": "human_lung_2021_10xtechnology_mock1_001_no_doi_mock1"}]) -def test_split_index_sets(data_type: str, randomized_batch_access: bool, test_split): +def test_split_index_sets(data_type: str, test_split): """ Test that train, val, test split index sets are correct: @@ -422,8 +435,8 @@ def test_split_index_sets(data_type: str, randomized_batch_access: bool, test_sp test_estim = HelperEstimatorKerasEmbedding() # Need full feature space here because observations are not necessarily different in small model testing feature # space with only two genes: - test_estim.load_estimator(model_type="linear", data_type=data_type, test_split=test_split, feature_space="full", - organism="human") + test_estim.load_estimator(model_type="linear", 
data_type=data_type, feature_space="full", organism="human", + test_split=test_split) idx_train = test_estim.estimator.idx_train idx_eval = test_estim.estimator.idx_eval idx_test = test_estim.estimator.idx_test @@ -432,10 +445,10 @@ def test_split_index_sets(data_type: str, randomized_batch_access: bool, test_sp assert len(idx_train) == len(np.unique(idx_train)) assert len(idx_eval) == len(np.unique(idx_eval)) assert len(idx_test) == len(np.unique(idx_test)) - assert len(idx_train) + len(idx_eval) + len(idx_test) == test_estim.data.n_obs, \ - (len(idx_train), len(idx_eval), len(idx_test), test_estim.data.n_obs) + assert len(idx_train) + len(idx_eval) + len(idx_test) == test_estim.estimator.data.n_obs, \ + (len(idx_train), len(idx_eval), len(idx_test), test_estim.estimator.data.n_obs) if isinstance(test_estim.data, DistributedStoreSingleFeatureSpace): - assert np.sum([v.shape[0] for v in test_estim.data.adata_by_key.values()]) == test_estim.data.n_obs + assert np.sum([v.shape[0] for v in test_estim.data.indices.values()]) == test_estim.estimator.data.n_obs # 2) Assert that index assignments are exclusive to each split: assert len(set(idx_train).intersection(set(idx_eval))) == 0 assert len(set(idx_train).intersection(set(idx_test))) == 0 @@ -445,7 +458,7 @@ def test_split_index_sets(data_type: str, randomized_batch_access: bool, test_sp # Prepare data set-wise index vectors that are numbered in the same way as global split index vectors. idx_raw = [] counter = 0 - for v in test_estim.data.indices.values(): + for v in test_estim.estimator.data.indices.values(): idx_raw.append(np.arange(counter, counter + len(v))) counter += len(v) if isinstance(test_split, float): @@ -471,29 +484,15 @@ def test_split_index_sets(data_type: str, randomized_batch_access: bool, test_sp for x in idx_raw ])[0] assert np.all(datasets_train == datasets_eval), (datasets_train, datasets_eval, datasets_test) - assert len(set(datasets_train).intersection(set(datasets_test))) == 0, \ - (datasets_train, datasets_eval, datasets_test) + assert len(set(datasets_train).intersection(set(datasets_test))) == 0, (datasets_train, datasets_test) # 4) Assert that observations mapped to indices are actually unique based on expression vectors: # Build numpy arrays of expression input data sets from tensorflow data sets directly from estimator. # These data sets are the most processed transformation of the data and stand directly in concat with the model. 
- shuffle_buffer_size = None if randomized_batch_access else 2 - ds_train = test_estim.estimator._get_dataset(idx=idx_train, batch_size=1024, mode='eval', - shuffle_buffer_size=shuffle_buffer_size, - retrieval_batch_size=2048, - randomized_batch_access=randomized_batch_access) - ds_eval = test_estim.estimator._get_dataset(idx=idx_eval, batch_size=1024, mode='eval', - shuffle_buffer_size=shuffle_buffer_size, - retrieval_batch_size=2048, - randomized_batch_access=randomized_batch_access) - ds_test = test_estim.estimator._get_dataset(idx=idx_test, batch_size=1024, mode='eval', - shuffle_buffer_size=shuffle_buffer_size, - retrieval_batch_size=2048, - randomized_batch_access=randomized_batch_access) + ds_train = test_estim.estimator.get_one_time_tf_dataset(idx=idx_train, batch_size=1024, mode='eval') + ds_eval = test_estim.estimator.get_one_time_tf_dataset(idx=idx_eval, batch_size=1024, mode='eval') + ds_test = test_estim.estimator.get_one_time_tf_dataset(idx=idx_test, batch_size=1024, mode='eval') # Create two copies of test data set to make sure that re-instantiation of a subset does not cause issues. - ds_test2 = test_estim.estimator._get_dataset(idx=idx_test, batch_size=1024, mode='eval', - shuffle_buffer_size=shuffle_buffer_size, - retrieval_batch_size=2048, - randomized_batch_access=randomized_batch_access) + ds_test2 = test_estim.estimator.get_one_time_tf_dataset(idx=idx_test, batch_size=1024, mode='eval') x_train = [] x_eval = [] x_test = [] @@ -512,8 +511,8 @@ def test_split_index_sets(data_type: str, randomized_batch_access: bool, test_sp x_test2_shape += x[0].shape[0] assert x_test2_shape == x_test.shape[0] # Validate size of recovered numpy data sets: - assert x_train.shape[0] + x_eval.shape[0] + x_test.shape[0] == test_estim.data.n_obs - assert len(idx_train) + len(idx_eval) + len(idx_test) == test_estim.data.n_obs + assert x_train.shape[0] + x_eval.shape[0] + x_test.shape[0] == test_estim.estimator.data.n_obs + assert len(idx_train) + len(idx_eval) + len(idx_test) == test_estim.estimator.data.n_obs assert x_train.shape[0] == len(idx_train) assert x_eval.shape[0] == len(idx_eval) assert x_test.shape[0] == len(idx_test) diff --git a/sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py b/sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py index 3fdf83365..826999899 100644 --- a/sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py +++ b/sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py @@ -8,11 +8,25 @@ from sfaira.ui import ModelZoo from sfaira.versions.metadata import CelltypeUniverse, OntologyCl, OntologyUberon -from sfaira.unit_tests.tests_by_submodule.estimators import TestHelperEstimatorBase, TARGETS +from sfaira.unit_tests.tests_by_submodule.estimators import HelperEstimatorBase, TARGETS from sfaira.unit_tests import DIR_TEMP -class HelperTrainerBase(TestHelperEstimatorBase): +def get_cu(): + """ + Get file name of a target universe for loading by trainer. + """ + # Create temporary cell type universe to give to trainer. + fn = os.path.join(DIR_TEMP, "universe_temp.csv") + cl = OntologyCl(branch="v2021-02-01") + uberon = OntologyUberon() + cu = CelltypeUniverse(cl=cl, uberon=uberon) + cu.write_target_universe(fn=fn, x=TARGETS) + del cu + return fn + + +class HelperTrainerBase(HelperEstimatorBase): data: Union[anndata.AnnData, load_store] trainer: Union[TrainModelCelltype, TrainModelEmbedding] @@ -58,13 +72,7 @@ def test_save_embedding(): def test_save_celltypes(): - # Create temporary cell type universe to give to trainer. 
- tmp_fn = os.path.join(DIR_TEMP, "universe_temp.csv") - cl = OntologyCl(branch="v2021-02-01") - uberon = OntologyUberon() - cu = CelltypeUniverse(cl=cl, uberon=uberon) - cu.write_target_universe(fn=tmp_fn, x=TARGETS) - del cu + tmp_fn = get_cu() model_id = "celltype_human-lung-mlp-0.0.1-0.1_mylab" zoo = ModelZoo() zoo.model_id = model_id diff --git a/sfaira/unit_tests/tests_by_submodule/ui/test_userinterface.py b/sfaira/unit_tests/tests_by_submodule/ui/test_userinterface.py index 37fd6ba49..bc3eac6cd 100644 --- a/sfaira/unit_tests/tests_by_submodule/ui/test_userinterface.py +++ b/sfaira/unit_tests/tests_by_submodule/ui/test_userinterface.py @@ -6,7 +6,7 @@ from sfaira.unit_tests import DIR_TEMP -class TestUi: +class HelperUi: ui: Union[UserInterface] data: np.ndarray @@ -27,7 +27,7 @@ def simulate(self): """ pass - def _test_basic(self): + def test_basic(self): """ Test all relevant model methods. @@ -36,3 +36,11 @@ def _test_basic(self): """ temp_fn = os.path.join(DIR_TEMP, "test_data") self.ui = UserInterface(custom_repo=temp_fn, sfaira_repo=False) + + +def _test_for_fatal(): + """ + TODO need to simulate/add look up table as part of unit tests locally + """ + ui = HelperUi() + ui.test_basic() diff --git a/sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py b/sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py index 32b81fb78..3b847727b 100644 --- a/sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py +++ b/sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py @@ -1,6 +1,6 @@ import numpy as np -from sfaira.versions.metadata import OntologyUberon, OntologyCl, OntologyMondo, OntologyMmusdv, OntologyHsapdv, \ - OntologySinglecellLibraryConstruction +from sfaira.versions.metadata import OntologyUberon, OntologyCl, OntologyHancestro, OntologyHsapdv, OntologyMondo, \ + OntologyMmusdv, OntologySinglecellLibraryConstruction """ OntologyCelltypes @@ -11,7 +11,8 @@ def test_cl_loading(): """ Tests if ontology can be initialised. """ - _ = OntologyCl(branch="v2021-02-01") + _ = OntologyCl(branch="v2021-02-01", recache=True) + _ = OntologyCl(branch="v2021-02-01", recache=False) def test_cl_is_a(): @@ -73,8 +74,11 @@ def test_cl_set_leaves(): Hancestro """ -# def test_hancestro_loading(): -# _ = OntologyHancestro() + +def test_hancestro_loading(): + _ = OntologyHancestro(recache=True) + _ = OntologyHancestro(recache=False) + """ Hsapdv @@ -82,7 +86,8 @@ def test_cl_set_leaves(): def test_hsapdv_loading(): - _ = OntologyHsapdv() + _ = OntologyHsapdv(recache=True) + _ = OntologyHsapdv(recache=False) """ @@ -91,7 +96,8 @@ def test_hsapdv_loading(): def test_mondo_loading(): - _ = OntologyMondo() + _ = OntologyMondo(recache=True) + _ = OntologyMondo(recache=False) """ @@ -100,7 +106,8 @@ def test_mondo_loading(): def test_mmusdv_loading(): - _ = OntologyMmusdv() + _ = OntologyMmusdv(recache=True) + _ = OntologyMmusdv(recache=False) """ @@ -112,7 +119,8 @@ def test_sclc_loading(): """ Tests if ontology can be initialised. 
""" - _ = OntologySinglecellLibraryConstruction() + _ = OntologySinglecellLibraryConstruction(recache=True) + _ = OntologySinglecellLibraryConstruction(recache=False) def test_sclc_nodes(): @@ -147,7 +155,8 @@ def test_sclc_is_a(): def test_uberon_loading(): - _ = OntologyUberon() + _ = OntologyUberon(recache=True) + _ = OntologyUberon(recache=False) def test_uberon_subsetting(): diff --git a/sfaira/versions/genomes/genomes.py b/sfaira/versions/genomes/genomes.py index 0dedd4c01..728bb276a 100644 --- a/sfaira/versions/genomes/genomes.py +++ b/sfaira/versions/genomes/genomes.py @@ -291,6 +291,7 @@ class CustomFeatureContainer(GenomeContainer): def __init__( self, genome_tab: pandas.DataFrame, + organism: str, ): """ @@ -306,3 +307,8 @@ def __init__( assert KEY_ID in genome_tab.columns assert KEY_TYPE in genome_tab.columns self.genome_tab = genome_tab + self._organism = organism + + @property + def organism(self): + return self._organism diff --git a/sfaira/versions/metadata/__init__.py b/sfaira/versions/metadata/__init__.py index 098f3dc2f..c4848a854 100644 --- a/sfaira/versions/metadata/__init__.py +++ b/sfaira/versions/metadata/__init__.py @@ -1,4 +1,4 @@ from sfaira.versions.metadata.base import Ontology, OntologyList, OntologyHierarchical, OntologyObo, \ - OntologyOboCustom, OntologyCl, OntologyUberon, OntologyHsapdv, OntologyMondo, OntologyMmusdv, \ - OntologySinglecellLibraryConstruction, OntologyCellosaurus + OntologyOboCustom, OntologyCl, OntologyHancestro, OntologyUberon, OntologyHsapdv, OntologyMondo, \ + OntologyMmusdv, OntologySinglecellLibraryConstruction, OntologyCellosaurus from sfaira.versions.metadata.universe import CelltypeUniverse diff --git a/sfaira/versions/metadata/base.py b/sfaira/versions/metadata/base.py index ae9e420c9..c7ca1a919 100644 --- a/sfaira/versions/metadata/base.py +++ b/sfaira/versions/metadata/base.py @@ -183,7 +183,10 @@ def nodes_dict(self) -> dict: @property def node_names(self) -> List[str]: - return [x["name"] for x in self.graph.nodes.values()] + try: + return [x["name"] for x in self.graph.nodes.values()] + except KeyError as e: + raise KeyError(f"KeyError '{e}' in {type(self)}") @property def node_ids(self) -> List[str]: @@ -375,13 +378,22 @@ def __init__( recache: bool, **kwargs ): + # Note on base URL: EBI OLS points to different resources depending on the ontology used, this needs to be + # accounted for here. + if ontology == "hancestro": + base_url = f"https://www.ebi.ac.uk/ols/api/ontologies/{ontology}/terms/" \ + f"http%253A%252F%252Fpurl.obolibrary.org%252Fobo%252F" + elif ontology == "efo": + base_url = f"https://www.ebi.ac.uk/ols/api/ontologies/{ontology}/terms/" \ + f"http%253A%252F%252Fwww.ebi.ac.uk%252F{ontology}%252F" + else: + assert False + def get_url_self(iri): - return f"https://www.ebi.ac.uk/ols/api/ontologies/{ontology}/terms/" \ - f"http%253A%252F%252Fwww.ebi.ac.uk%252F{ontology}%252F{iri}" + return f"{base_url}{iri}" def get_url_children(iri): - return f"https://www.ebi.ac.uk/ols/api/ontologies/{ontology}/terms/" \ - f"http%253A%252F%252Fwww.ebi.ac.uk%252F{ontology}%252F{iri}/children" + return f"{base_url}{iri}/children" def get_iri_from_node(x): return x["iri"].split("/")[-1] @@ -413,7 +425,7 @@ def recursive_search(iri): direct_children = [] k_self = get_id_from_iri(iri) # Define root node if this is the first iteration, this node is otherwise not defined through values. 
- if k_self == "EFO:0010183": + if k_self == ":".join(root_term.split("_")): terms_self = requests.get(get_url_self(iri=iri)).json() nodes_new[k_self] = { "name": terms_self["label"], @@ -977,8 +989,31 @@ def synonym_node_properties(self) -> List[str]: return ["synonym"] +class OntologyHancestro(OntologyEbi): + + """ + TODO move this to .owl backend once available. + TODO root term: No term HANCESTRO_0001 ("Thing"?) accessible through EBI interface, because of that country-related + higher order terms are not available as they are parallel to HANCESTRO_0004. Maybe fix with .owl backend? + """ + + def __init__(self, recache: bool = False): + super().__init__( + ontology="hancestro", + root_term="HANCESTRO_0004", + additional_terms={}, + additional_edges=[], + ontology_cache_fn="hancestro.pickle", + recache=recache, + ) + + class OntologySinglecellLibraryConstruction(OntologyEbi): + """ + TODO CITE set not in API yet, added two nodes and edges temporarily. + """ + def __init__(self, recache: bool = False): super().__init__( ontology="efo", @@ -986,10 +1021,14 @@ def __init__(self, recache: bool = False): additional_terms={ "sci-plex": {"name": "sci-plex"}, "sci-RNA-seq": {"name": "sci-RNA-seq"}, + "EFO:0009294": {"name": "CITE-seq"}, # TODO not in API yet + "EFO:0030008": {"name": "CITE-seq (cell surface protein profiling)"}, # TODO not in API yet }, additional_edges=[ ("EFO:0010183", "sci-plex"), ("EFO:0010183", "sci-RNA-seq"), + ("EFO:0010183", "EFO:0009294"), # TODO not in API yet + ("EFO:0009294", "EFO:0030008"), # TODO not in API yet ], ontology_cache_fn="efo.pickle", recache=recache, From 0320b154e0561c38a3030c03df9ee4ae474c0bfe Mon Sep 17 00:00:00 2001 From: "David S. Fischer" Date: Tue, 7 Sep 2021 12:00:25 +0200 Subject: [PATCH 10/15] added uberon versioning (#354) * added uberon versioning --- sfaira/consts/ontologies.py | 3 ++- sfaira/unit_tests/tests_by_submodule/data/test_store.py | 3 ++- .../unit_tests/tests_by_submodule/trainer/test_trainer.py | 2 +- .../tests_by_submodule/versions/test_ontologies.py | 6 +++--- .../unit_tests/tests_by_submodule/versions/test_universe.py | 2 +- sfaira/versions/metadata/base.py | 3 ++- 6 files changed, 11 insertions(+), 8 deletions(-) diff --git a/sfaira/consts/ontologies.py b/sfaira/consts/ontologies.py index dbf8c44b5..8e0ed4b69 100644 --- a/sfaira/consts/ontologies.py +++ b/sfaira/consts/ontologies.py @@ -5,6 +5,7 @@ OntologyMmusdv, OntologySinglecellLibraryConstruction, OntologyUberon DEFAULT_CL = "v2021-02-01" +DEFAULT_UBERON = "2019-11-22" class OntologyContainerSfaira: @@ -126,5 +127,5 @@ def ethnicity(self): @property def organ(self): if self._organ is None: - self._organ = OntologyUberon() + self._organ = OntologyUberon(branch=DEFAULT_UBERON) return self._organ diff --git a/sfaira/unit_tests/tests_by_submodule/data/test_store.py b/sfaira/unit_tests/tests_by_submodule/data/test_store.py index 1a380f9a7..8b1e6a282 100644 --- a/sfaira/unit_tests/tests_by_submodule/data/test_store.py +++ b/sfaira/unit_tests/tests_by_submodule/data/test_store.py @@ -162,7 +162,8 @@ def test_generator_shapes(store_format: str, idx, batch_size: int, obs_keys: Lis """ Test generators queries do not throw errors and that output shapes are correct. 
""" - store_path = prepare_store(store_format=store_format) + # Need to re-write because specific obs_keys are required: + store_path = prepare_store(store_format=store_format, rewrite_store=True) store = load_store(cache_path=store_path, store_format=store_format) store.subset(attr_key="organism", values=["mouse"]) gc = GenomeContainer(assembly=ASSEMBLY_MOUSE) diff --git a/sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py b/sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py index 826999899..b4bb106b4 100644 --- a/sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py +++ b/sfaira/unit_tests/tests_by_submodule/trainer/test_trainer.py @@ -19,7 +19,7 @@ def get_cu(): # Create temporary cell type universe to give to trainer. fn = os.path.join(DIR_TEMP, "universe_temp.csv") cl = OntologyCl(branch="v2021-02-01") - uberon = OntologyUberon() + uberon = OntologyUberon(branch="2019-11-22") cu = CelltypeUniverse(cl=cl, uberon=uberon) cu.write_target_universe(fn=fn, x=TARGETS) del cu diff --git a/sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py b/sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py index 3b847727b..6af897ccb 100644 --- a/sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py +++ b/sfaira/unit_tests/tests_by_submodule/versions/test_ontologies.py @@ -155,12 +155,12 @@ def test_sclc_is_a(): def test_uberon_loading(): - _ = OntologyUberon(recache=True) - _ = OntologyUberon(recache=False) + _ = OntologyUberon(branch="2019-11-22", recache=True) + _ = OntologyUberon(branch="2019-11-22", recache=False) def test_uberon_subsetting(): - ou = OntologyUberon() + ou = OntologyUberon(branch="2019-11-22") assert ou.is_a(query="lobe of lung", reference="lung") assert ou.is_a(query="lobe of lung", reference="lobe of lung") assert not ou.is_a(query="lung", reference="lobe of lung") diff --git a/sfaira/unit_tests/tests_by_submodule/versions/test_universe.py b/sfaira/unit_tests/tests_by_submodule/versions/test_universe.py index 23222b613..08975a2b0 100644 --- a/sfaira/unit_tests/tests_by_submodule/versions/test_universe.py +++ b/sfaira/unit_tests/tests_by_submodule/versions/test_universe.py @@ -15,7 +15,7 @@ def test_universe_io(): targets = ["stromal cell", "lymphocyte", "T-helper 1 cell", "T-helper 17 cell"] leaves_target = ["stromal cell", "T-helper 1 cell", "T-helper 17 cell"] cl = OntologyCl(branch="v2021-02-01") - uberon = OntologyUberon() + uberon = OntologyUberon(branch="2019-11-22") cu = CelltypeUniverse(cl=cl, uberon=uberon) cu.write_target_universe(fn=tmp_fn, x=targets) cu.load_target_universe(fn=tmp_fn) diff --git a/sfaira/versions/metadata/base.py b/sfaira/versions/metadata/base.py index c7ca1a919..25d5553de 100644 --- a/sfaira/versions/metadata/base.py +++ b/sfaira/versions/metadata/base.py @@ -598,11 +598,12 @@ class OntologyUberon(OntologyExtendedObo): def __init__( self, + branch: str, recache: bool = False, **kwargs ): obofile = cached_load_obo( - url="http://purl.obolibrary.org/obo/uberon.obo", + url=f"https://svn.code.sf.net/p/obo/svn/uberon/releases/{branch}/ext.obo", ontology_cache_dir="uberon", ontology_cache_fn="uberon.obo", recache=recache, From 17d3eff16b591bf07710afb696125a3699e10920 Mon Sep 17 00:00:00 2001 From: "David S. 
Fischer" Date: Tue, 7 Sep 2021 12:29:56 +0200 Subject: [PATCH 11/15] added data life cycle rst (#355) --- docs/adding_datasets.rst | 3 +++ docs/consuming_data.rst | 4 ++++ docs/data_life_cycle.rst | 33 +++++++++++++++++++++++++++++++++ docs/distributed_data.rst | 3 +++ docs/index.rst | 1 + 5 files changed, 44 insertions(+) create mode 100644 docs/data_life_cycle.rst diff --git a/docs/adding_datasets.rst b/docs/adding_datasets.rst index e39411e45..742e42f46 100644 --- a/docs/adding_datasets.rst +++ b/docs/adding_datasets.rst @@ -1,6 +1,9 @@ +.. _adding_data_rst: + Contributing data ================== +For a high-level overview of data management in sfaira, read :ref:`data_life_cycle_rst` first. Adding datasets to sfaira is a great way to increase the visibility of your dataset and to make it available to a large audience. This process requires a couple of steps as outlined in the following sections. diff --git a/docs/consuming_data.rst b/docs/consuming_data.rst index beb5a2466..0f69c486b 100644 --- a/docs/consuming_data.rst +++ b/docs/consuming_data.rst @@ -1,3 +1,5 @@ +.. _consuming_data_rst: + Consuming data =============== @@ -5,6 +7,8 @@ Consuming data :width: 600px :align: center +For a high-level overview of data management in sfaira, read :ref:`data_life_cycle_rst` first. + Build data repository locally ------------------------------ diff --git a/docs/data_life_cycle.rst b/docs/data_life_cycle.rst new file mode 100644 index 000000000..4d5ae60bd --- /dev/null +++ b/docs/data_life_cycle.rst @@ -0,0 +1,33 @@ +.. _data_life_cycle_rst: + +The data life cycle +=================== + +The life cycle of a single-cell count matrix often looks as follows: + + 1. **Generation** from primary read data in a read alignment pipeline. + 2. **Annotation** with cell types and sample meta data. + 3. **Publication** of annotated data, often together with a manuscript. + 4. **Curation** of this public data set for the purpose of a meta study. In a python workflow, this curation step could be a scanpy script based on data from step 3, for example. + 5. **Usage** of data curated specifically for the use case at hand, for example for a targeted analysis or a training of a machine learning model. + +where step 1-3 is often only performed once by the original authors of the data set, +while step 4 and 5 are repeated multiple times in the community for different meta studies. +Sfaira offers the following functionality groups that accelerate steps along this pipeline: + +I) Data loaders +~~~~~~~~~~~~~~~ +We maintain streamlined data loader code that improve **Curation** (step 4) and make this step sharable and iteratively improvable. +Read more in our guide to data contribution :ref:`adding_data_rst`. + +II) Dataset, DatasetGroup, DatasetSuperGroup +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Using the data loaders from (I), we built an interface that can flexibly download, subset and curate data sets from the sfaira data zoo, thus improving **Usage** (step 5). +This interface can yield adata instances to be used in a scanpy pipeline, for example. +Read more in our guide to data consumption :ref:`consuming_data_rst`. + +III) Stores +~~~~~~~~~~~ +Using the streamlined data set collections from (II), we built a computationally efficient data interface for machine learning on such large distributed data set collection, thus improving **Usage** (step 5): +Specifically, this interface is optimised for out-of-core observation-centric indexing in scenarios that are typical to machine learning on single-cell data. 
+Read more in our guide to data stores :ref:`distributed_data_rst`. diff --git a/docs/distributed_data.rst b/docs/distributed_data.rst index 02b0ca127..911a1e7db 100644 --- a/docs/distributed_data.rst +++ b/docs/distributed_data.rst @@ -1,6 +1,9 @@ +.. _distributed_data_rst: + Distributed data ================ +For a high-level overview of data management in sfaira, read :ref:`data_life_cycle_rst` first. Sfaira supports usage of distributed data for model training and execution. The tools are summarized under `sfaira.data.store`. In contrast to using an instance of AnnData in memory, these tools can be used to use data sets that are saved diff --git a/docs/index.rst b/docs/index.rst index 858632fec..133bbe6c4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -38,6 +38,7 @@ Latest additions api commandline_interface tutorials + data_life_cycle adding_datasets consuming_data distributed_data From 2ef1e01e31e132ce12f334227a91be2cb6a01bee Mon Sep 17 00:00:00 2001 From: "David S. Fischer" Date: Tue, 7 Sep 2021 13:24:27 +0200 Subject: [PATCH 12/15] Release 0.3.5 (#357) From 74c206af960cd9704aa41aa10cb37115844bd86a Mon Sep 17 00:00:00 2001 From: "David S. Fischer" Date: Wed, 8 Sep 2021 14:12:06 +0200 Subject: [PATCH 13/15] Bug/universe loading (#359) * added more verbose error to empty file handle * depreceated copy in project_free_to_ontology and fixed bug that arose centred on copy * added obs surveying script --- sfaira/data/dataloaders/base/dataset.py | 26 +++++++------------ sfaira/data/dataloaders/base/dataset_group.py | 2 +- .../data/dataloaders/loaders/super_group.py | 15 ++++++----- .../utils_scripts/survey_obs_annotation.py | 26 +++++++++++++++++++ 4 files changed, 45 insertions(+), 24 deletions(-) create mode 100644 sfaira/data/utils_scripts/survey_obs_annotation.py diff --git a/sfaira/data/dataloaders/base/dataset.py b/sfaira/data/dataloaders/base/dataset.py index 5aa0b7adc..c2de82a52 100644 --- a/sfaira/data/dataloaders/base/dataset.py +++ b/sfaira/data/dataloaders/base/dataset.py @@ -1120,17 +1120,12 @@ def read_ontology_class_map(self, fn): if self.cell_type_obs_key is not None: warnings.warn(f"file {fn} does not exist but cell_type_obs_key {self.cell_type_obs_key} is given") - def project_free_to_ontology(self, attr: str, copy: bool = False): + def project_free_to_ontology(self, attr: str): """ Project free text cell type names to ontology based on mapping table. ToDo: add ontology ID setting here. ToDo: only for cell type right now, extend to other meta data in the future. - - :param copy: If True, a dataframe with the celltype annotation is returned, otherwise self.adata.obs is updated - inplace. - - :return: """ ontology_map = attr + "_map" if hasattr(self, ontology_map): @@ -1139,7 +1134,6 @@ def project_free_to_ontology(self, attr: str, copy: bool = False): ontology_map = None print(f"WARNING: did not find ontology map for {attr} which was only defined by free annotation") adata_fields = self._adata_ids - results = {} col_original = attr + adata_fields.onto_original_suffix labels_original = self.adata.obs[col_original].values if ontology_map is not None: # only if this was defined @@ -1173,19 +1167,17 @@ def project_free_to_ontology(self, attr: str, copy: bool = False): # TODO this could be changed in the future, this allows this function to be used both on cell type name # mapping files with and without the ID in the third column. # This mapping blocks progression in the unit test if not deactivated. 
- results[getattr(adata_fields, attr)] = labels_mapped + self.adata.obs[getattr(adata_fields, attr)] = labels_mapped self.__project_ontology_ids_obs(attr=attr, map_exceptions=map_exceptions, from_id=False, adata_ids=adata_fields) else: - results[getattr(adata_fields, attr)] = labels_original - results[getattr(adata_fields, attr) + adata_fields.onto_id_suffix] = \ + # Assumes that the original labels are the correct ontology symbols, because of a lack of ontology, + # ontology IDs cannot be inferred. + # TODO is this necessary in the future? + self.adata.obs[getattr(adata_fields, attr)] = labels_original + self.adata.obs[getattr(adata_fields, attr) + adata_fields.onto_id_suffix] = \ [adata_fields.unknown_metadata_identifier] * self.adata.n_obs - results[getattr(adata_fields, attr) + adata_fields.onto_original_suffix] = labels_original - if copy: - return pd.DataFrame(results, index=self.adata.obs.index) - else: - for k, v in results.items(): - self.adata.obs[k] = v + self.adata.obs[getattr(adata_fields, attr) + adata_fields.onto_original_suffix] = labels_original def __impute_ontology_cols_obs( self, @@ -1238,7 +1230,7 @@ def __impute_ontology_cols_obs( # Original annotation (free text): original_present = col_original in self.adata.obs.columns if original_present and not symbol_present and not id_present: # 1) - self.project_free_to_ontology(attr=attr, copy=False) + self.project_free_to_ontology(attr=attr) if symbol_present or id_present: # 2) if symbol_present and not id_present: # 2a) self.__project_ontology_ids_obs(attr=attr, from_id=False, adata_ids=adata_ids) diff --git a/sfaira/data/dataloaders/base/dataset_group.py b/sfaira/data/dataloaders/base/dataset_group.py index 56f213c3b..b7dc5a042 100644 --- a/sfaira/data/dataloaders/base/dataset_group.py +++ b/sfaira/data/dataloaders/base/dataset_group.py @@ -678,7 +678,7 @@ def __init__( elif package_source == "sfaira_extension": package_source = "sfairae" else: - raise ValueError(f"invalid package source {package_source} for {self._cwd}, {self.collection_id}") + raise ValueError(f"invalid package source {package_source} for {self._cwd}") except IndexError as e: raise IndexError(f"{e} for {self._cwd}") loader_pydoc_path_sfaira = "sfaira.data.dataloaders.loaders." 
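The free-text-to-ontology projection above is, at its core, a table lookup from author-provided labels to ontology symbols, with an explicit fallback for labels that cannot be mapped. A minimal, self-contained sketch of that idea follows; the `source`/`target` column names and the `free_annotation` key are invented for illustration, and real map files may additionally carry ontology IDs in a third column, as noted in the code above:

import anndata
import numpy as np
import pandas as pd

# Hypothetical two-column mapping table: free-text label -> ontology symbol.
ontology_map = pd.DataFrame({
    "source": ["T cells", "B cells"],
    "target": ["T cell", "B cell"],
})
lookup = dict(zip(ontology_map["source"], ontology_map["target"]))

adata = anndata.AnnData(
    X=np.zeros((3, 2)),
    obs=pd.DataFrame(
        {"free_annotation": ["T cells", "B cells", "unclear"]},
        index=["cell_0", "cell_1", "cell_2"],
    ),
)

# Labels not covered by the table fall back to an explicit unknown identifier
# rather than raising, loosely mirroring the unknown handling sketched above.
adata.obs["cell_type"] = [
    lookup.get(x, "UNKNOWN") for x in adata.obs["free_annotation"].values
]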
diff --git a/sfaira/data/dataloaders/loaders/super_group.py b/sfaira/data/dataloaders/loaders/super_group.py index 2ee26a1b1..e5354ac5e 100644 --- a/sfaira/data/dataloaders/loaders/super_group.py +++ b/sfaira/data/dataloaders/loaders/super_group.py @@ -35,12 +35,15 @@ def __init__( if f[:len(dir_prefix)] == dir_prefix and f not in dir_exclude: # Narrow down to data set directories path_dsg = str(pydoc.locate(f"sfaira.data.dataloaders.loaders.{f}.FILE_PATH")) if path_dsg is not None: - dataset_groups.append(DatasetGroupDirectoryOriented( - file_base=path_dsg, - data_path=data_path, - meta_path=meta_path, - cache_path=cache_path - )) + try: + dataset_groups.append(DatasetGroupDirectoryOriented( + file_base=path_dsg, + data_path=data_path, + meta_path=meta_path, + cache_path=cache_path + )) + except IndexError as e: + raise IndexError(f"{e} for '{cwd}', '{f}', '{path_dsg}'") else: warn(f"DatasetGroupDirectoryOriented was None for {f}") super().__init__(dataset_groups=dataset_groups) diff --git a/sfaira/data/utils_scripts/survey_obs_annotation.py b/sfaira/data/utils_scripts/survey_obs_annotation.py new file mode 100644 index 000000000..80d66c08f --- /dev/null +++ b/sfaira/data/utils_scripts/survey_obs_annotation.py @@ -0,0 +1,26 @@ +import numpy as np +import sfaira +import sys + +# Set global variables. +print("sys.argv", sys.argv) + +data_path = str(sys.argv[1]) +path_meta = str(sys.argv[2]) +path_cache = str(sys.argv[3]) + +universe = sfaira.data.dataloaders.Universe( + data_path=data_path, meta_path=path_meta, cache_path=path_cache +) +for k, v in universe.datasets.items(): + print(k) + v.load( + load_raw=False, + allow_caching=True, + ) + for col in v.adata.obs.columns: + val = np.sort(np.unique(v.adata.obs[col].values)) + if len(val) > 20: + val = val[:20] + print(f"{k}: {col}: {val}") + v.clear() From 2b66d04a3c712c9fd088dca037f359219423cf58 Mon Sep 17 00:00:00 2001 From: "David S. Fischer" Date: Wed, 8 Sep 2021 17:31:19 +0200 Subject: [PATCH 14/15] Bug/ontology celltypes (#360) * fixed access to ontology in dataset group --- sfaira/data/dataloaders/base/dataset_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sfaira/data/dataloaders/base/dataset_group.py b/sfaira/data/dataloaders/base/dataset_group.py index b7dc5a042..7000d6a9d 100644 --- a/sfaira/data/dataloaders/base/dataset_group.py +++ b/sfaira/data/dataloaders/base/dataset_group.py @@ -512,7 +512,7 @@ def ontology_celltypes(self): # ToDo: think about whether this should be handled differently. warnings.warn("found more than one organism in group, this could cause problems with using a joined cell " "type ontology. Using only the ontology of the first data set in the group.") - return self.datasets[self.ids[0]].ontology_celltypes + return self.datasets[self.ids[0]].ontology_container_sfaira.cell_type def project_celltypes_to_ontology(self, adata_fields: Union[AdataIds, None] = None, copy=False): """ From e23a87800727b65d21b78fe6c318b6af4763c824 Mon Sep 17 00:00:00 2001 From: "David S. 
Fischer" Date: Thu, 9 Sep 2021 20:51:00 +0200 Subject: [PATCH 15/15] Feature/map fn vectorisation (#364) * enabled map_fn vectorisation in generators * sped up store access --- sfaira/data/store/batch_schedule.py | 18 +-- sfaira/data/store/generators.py | 108 +++++++++++------- sfaira/data/store/single_store.py | 9 +- sfaira/estimators/keras.py | 67 ++++++----- sfaira/models/celltype/marker.py | 2 +- sfaira/models/celltype/mlp.py | 2 +- sfaira/models/embedding/ae.py | 2 +- sfaira/models/embedding/linear.py | 2 +- .../tests_by_submodule/data/test_store.py | 20 ++-- .../estimators/test_estimator.py | 9 +- 10 files changed, 135 insertions(+), 104 deletions(-) diff --git a/sfaira/data/store/batch_schedule.py b/sfaira/data/store/batch_schedule.py index 9a7f86e03..6a2a10718 100644 --- a/sfaira/data/store/batch_schedule.py +++ b/sfaira/data/store/batch_schedule.py @@ -24,6 +24,7 @@ class BatchDesignBase: def __init__(self, retrieval_batch_size: int, randomized_batch_access: bool, random_access: bool, **kwargs): self.retrieval_batch_size = retrieval_batch_size + self._batch_bounds = None self._idx = None if randomized_batch_access and random_access: raise ValueError("Do not use randomized_batch_access and random_access.") @@ -51,7 +52,7 @@ def idx(self, x): self._idx = np.sort(x) # Sorted indices improve accession efficiency in some cases. @property - def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]: + def design(self) -> Tuple[np.ndarray, np.ndarray, List[Tuple[int, int]]]: """ Yields index objects for one epoch of all data. @@ -59,7 +60,8 @@ def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]: Randomization is performed anew with every call to this property. :returns: Tuple of: - - Ordering of observations in epoch. + - Indices of observations in epoch out of selected data set (ie indices of self.idx). + - Ordering of observations in epoch out of full data set. - Batch start and end indices for batch based on ordering defined in first output. """ raise NotImplementedError() @@ -68,14 +70,14 @@ def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]: class BatchDesignBasic(BatchDesignBase): @property - def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]: - idx_proc = self.idx.copy() + def design(self) -> Tuple[np.ndarray, np.ndarray, List[Tuple[int, int]]]: + idx_proc = np.arange(0, len(self.idx)) if self.random_access: np.random.shuffle(idx_proc) batch_bounds = self.batch_bounds.copy() if self.randomized_batch_access: batch_bounds = _randomize_batch_start_ends(batch_starts_ends=batch_bounds) - return idx_proc, batch_bounds + return idx_proc, self.idx[idx_proc], batch_bounds class BatchDesignBalanced(BatchDesignBase): @@ -110,15 +112,15 @@ def __init__(self, grouping, group_weights: dict, randomized_batch_access: bool, self.p_obs = p_obs @property - def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]: + def design(self) -> Tuple[np.ndarray, np.ndarray, List[Tuple[int, int]]]: # Re-sample index vector. - idx_proc = np.random.choice(a=self.idx, replace=True, size=len(self.idx), p=self.p_obs) + idx_proc = np.random.choice(a=np.arange(0, len(self.idx)), replace=True, size=len(self.idx), p=self.p_obs) if not self.random_access: # Note: randomization is result from sampling above, need to revert if not desired. 
idx_proc = np.sort(idx_proc) batch_bounds = self.batch_bounds.copy() if self.randomized_batch_access: batch_bounds = _randomize_batch_start_ends(batch_starts_ends=batch_bounds) - return idx_proc, batch_bounds + return idx_proc, self.idx[idx_proc], batch_bounds BATCH_SCHEDULE = { diff --git a/sfaira/data/store/generators.py b/sfaira/data/store/generators.py index a72ee9f56..173c94e76 100644 --- a/sfaira/data/store/generators.py +++ b/sfaira/data/store/generators.py @@ -8,14 +8,24 @@ from sfaira.data.store.batch_schedule import BATCH_SCHEDULE -def split_batch(x, obs): +def split_batch(x): """ Splits retrieval batch into consumption batches of length 1. Often, end-user consumption batches would be observation-wise, ie yield a first dimension of length 1. + + :param x: Data tuple of length 1 or 2: (input,) or (input, output,), where both input and output are also + a tuple, but of batch-dimensioned tensors. """ - for i in range(x.shape[0]): - yield x[i, :], obs.iloc[[i], :] + batch_dim = x[0][0].shape[0] + for i in range(batch_dim): + output = [] + for y in x: + if isinstance(y, tuple): + output.append(tuple([z[i, :] for z in y])) + else: + output.append(y[i, :]) + yield tuple(output) class GeneratorBase: @@ -92,6 +102,7 @@ def __init__(self, batch_schedule, batch_size, map_fn, obs_idx, obs_keys, var_id in self.adata_by_key. :param var_idx: The features to emit. """ + self.var_idx = var_idx self._obs_idx = None if not batch_size == 1: raise ValueError(f"Only batch size==1 is supported, found {batch_size}.") @@ -103,7 +114,6 @@ def __init__(self, batch_schedule, batch_size, map_fn, obs_idx, obs_keys, var_id self.schedule = batch_schedule(**kwargs) self.obs_idx = obs_idx self.obs_keys = obs_keys - self.var_idx = var_idx def _validate_idx(self, idx: Union[np.ndarray, list]) -> np.ndarray: """ @@ -170,7 +180,7 @@ def iterator(self) -> iter: # Speed up access to single object by skipping index overlap operations: def g(): - obs_idx, batch_bounds = self.schedule.design + _, obs_idx, batch_bounds = self.schedule.design for s, e in batch_bounds: idx_i = obs_idx[s:e] # Match adata objects that overlap to batch: @@ -204,13 +214,9 @@ def g(): x = x[:, self.var_idx] # Prepare .obs. obs = self.adata_dict[k].obs[self.obs_keys].iloc[v, :] - for x_i, obs_i in split_batch(x=x, obs=obs): - if self.map_fn is None: - yield x_i, obs_i - else: - output = self.map_fn(x_i, obs_i) - if output is not None: - yield output + data_tuple = self.map_fn(x, obs) + for data_tuple_i in split_batch(x=data_tuple): + yield data_tuple_i else: # Concatenates slices first before returning. Note that this is likely slower than emitting by # observation in most scenarios. @@ -250,64 +256,82 @@ def g(): self.adata_dict[k].obs[self.obs_keys].iloc[v, :] for k, v in idx_i_dict.items() ], axis=0, join="inner", ignore_index=True, copy=False) - if self.map_fn is None: - yield x, obs - else: - output = self.map_fn(x, obs) - if output is not None: - yield output + data_tuple = self.map_fn(x, obs) + yield data_tuple return g class GeneratorDask(GeneratorSingle): + """ + In addition to the full data array, x, this class maintains a slice _x_slice which is indexed by the iterator. + Access to the slice can be optimised with dask and is therefore desirable. 
+ """ + x: dask.array + _x_slice: dask.array obs: pd.DataFrame - def __init__(self, x, obs, **kwargs): + def __init__(self, x, obs, obs_keys, var_idx, **kwargs): + if var_idx is not None: + x = x[:, var_idx] self.x = x - super(GeneratorDask, self).__init__(**kwargs) - self.obs = obs[self.obs_keys] + self.obs = obs[obs_keys] # Redefine index so that .loc indexing can be used instead of .iloc indexing: self.obs.index = np.arange(0, obs.shape[0]) + self._x_slice = None + self._obs_slice = None + super(GeneratorDask, self).__init__(obs_keys=obs_keys, var_idx=var_idx, **kwargs) @property def n_obs(self) -> int: return self.x.shape[0] + @property + def obs_idx(self): + return self._obs_idx + + @obs_idx.setter + def obs_idx(self, x): + """ + Allows emission of different iterator on same generator instance (using same dask array). + In addition to base method: allows for optimisation of dask array for batch draws. + """ + if x is None: + x = np.arange(0, self.n_obs) + else: + x = self._validate_idx(x) + x = np.sort(x) + # Only reset if they are actually different: + if (self._obs_idx is not None and len(x) != len(self._obs_idx)) or np.any(x != self._obs_idx): + self._obs_idx = x + self.schedule.idx = x + self._x_slice = dask.optimize(self.x[self._obs_idx, :])[0] + self._obs_slice = self.obs.loc[self.obs.index[self._obs_idx], :] # TODO better than iloc? + # Redefine index so that .loc indexing can be used instead of .iloc indexing: + self._obs_slice.index = np.arange(0, self._obs_slice.shape[0]) + @property def iterator(self) -> iter: # Can all data sets corresponding to one organism as a single array because they share the second dimension # and dask keeps expression data and obs out of memory. def g(): - obs_idx, batch_bounds = self.schedule.design - x_temp = self.x[obs_idx, :] - obs_temp = self.obs.loc[self.obs.index[obs_idx], :] # TODO better than iloc? + obs_idx_slice, _, batch_bounds = self.schedule.design + x_temp = self._x_slice + obs_temp = self._obs_slice for s, e in batch_bounds: - x_i = x_temp[s:e, :] - if self.var_idx is not None: - x_i = x_i[:, self.var_idx] + x_i = x_temp[obs_idx_slice[s:e], :] # Exploit fact that index of obs is just increasing list of integers, so we can use the .loc[] # indexing instead of .iloc[]: - obs_i = obs_temp.loc[obs_temp.index[s:e], :] - # TODO place map_fn outside of for loop so that vectorisation in preprocessing can be used. + obs_i = obs_temp.loc[obs_temp.index[obs_idx_slice[s:e]], :] + data_tuple = self.map_fn(x_i, obs_i) if self.batch_size == 1: - for x_ii, obs_ii in split_batch(x=x_i, obs=obs_i): - if self.map_fn is None: - yield x_ii, obs_ii - else: - output = self.map_fn(x_ii, obs_ii) - if output is not None: - yield output + for data_tuple_i in split_batch(x=data_tuple): + yield data_tuple_i else: - if self.map_fn is None: - yield x_i, obs_i - else: - output = self.map_fn(x_i, obs_i) - if output is not None: - yield output + yield data_tuple return g diff --git a/sfaira/data/store/single_store.py b/sfaira/data/store/single_store.py index 4f4cd19ca..284069e64 100644 --- a/sfaira/data/store/single_store.py +++ b/sfaira/data/store/single_store.py @@ -583,15 +583,20 @@ def X_slice(self, idx: np.ndarray, as_sparse: bool = True, **kwargs) -> Union[np :return: Slice of data array. 
""" batch_size = min(len(idx), 128) + + def map_fn(x, obs): + return (x, ), + g = self.generator(idx=idx, retrieval_batch_size=batch_size, return_dense=True, random_access=False, - randomized_batch_access=False, **kwargs) + randomized_batch_access=False, map_fn=map_fn, **kwargs) shape = (idx.shape[0], self.n_vars) if as_sparse: x = scipy.sparse.csr_matrix(np.zeros(shape)) else: x = np.empty(shape) counter = 0 - for x_batch, _ in g.iterator(): + for x_batch, in g.iterator(): + x_batch = x_batch[0] batch_len = x_batch.shape[0] x[counter:(counter + batch_len), :] = x_batch counter += batch_len diff --git a/sfaira/estimators/keras.py b/sfaira/estimators/keras.py index 5888134ec..2aa3d4154 100644 --- a/sfaira/estimators/keras.py +++ b/sfaira/estimators/keras.py @@ -2,7 +2,6 @@ import anndata import hashlib import numpy as np -import scipy.sparse try: import tensorflow as tf except ImportError: @@ -31,12 +30,7 @@ def prepare_sf(x): """ Uses a minimal size factor of 1e-3 for total counts / 1e4 """ - if len(x.shape) == 2: - sf = np.asarray(x.sum(axis=1)).flatten() - elif len(x.shape) == 1: - sf = np.asarray(x.sum()).flatten() - else: - raise ValueError("x.shape > 2") + sf = np.asarray(x.sum(axis=1, keepdims=True)) sf = np.log(np.maximum(sf / 1e4, 1e-3)) return sf @@ -565,15 +559,17 @@ def init_model( def _tf_dataset_kwargs(self, mode: str): # Determine model type [ae, vae(iaf, vamp)] model_type = "vae" if self.model_type[:3] == "vae" else "ae" + output_types_x = (tf.float32, tf.float32) + output_shapes_x = (self.data.n_vars, 1) if mode == 'predict': # Output shape is same for predict mode regardless of model type - output_types = (tf.float32, tf.float32), - output_shapes = (self.data.n_vars, ()), + output_types = output_types_x, + output_shapes = output_shapes_x, elif model_type == "vae": - output_types = ((tf.float32, tf.float32), (tf.float32, tf.float32)) - output_shapes = ((self.data.n_vars, ()), (self.data.n_vars, ())) + output_types = (output_types_x, (tf.float32, tf.float32)) + output_shapes = (output_shapes_x, (self.data.n_vars, 1)) else: - output_types = ((tf.float32, tf.float32), tf.float32) - output_shapes = ((self.data.n_vars, ()), self.data.n_vars) + output_types = (output_types_x, (tf.float32, )) + output_shapes = (output_shapes_x, (self.data.n_vars, )) return {"output_types": output_types, "output_shapes": output_shapes} def _get_generator( @@ -588,14 +584,15 @@ def _get_generator( model_type = "vae" if self.model_type[:3] == "vae" else "ae" def map_fn(x_sample, obs_sample): - x_sample = np.asarray(x_sample).flatten() - sf_sample = prepare_sf(x=x_sample).flatten()[0] + x_sample = np.asarray(x_sample) + sf_sample = prepare_sf(x=x_sample) + output_x = (x_sample, sf_sample) if mode == 'predict': - output = (x_sample, sf_sample), + output = output_x, elif model_type == "vae": - output = (x_sample, sf_sample), (x_sample, sf_sample), + output = output_x, (x_sample, sf_sample) else: - output = (x_sample, sf_sample), x_sample + output = output_x, (x_sample, ) return output g = self.data.generator(idx=idx, retrieval_batch_size=retrieval_batch_size, obs_keys=[], map_fn=map_fn, @@ -924,16 +921,14 @@ def _get_celltype_out( return weights, y def _tf_dataset_kwargs(self, mode): + output_types_x = (tf.float32,) + output_shapes_x = (tf.TensorShape([self.data.n_vars]), ) if mode == 'predict': - output_types = (tf.float32,) - output_shapes = (tf.TensorShape([self.data.n_vars]),) + output_types = (output_types_x, ) + output_shapes = (output_shapes_x, ) else: - output_types = (tf.float32, tf.float32, 
tf.float32) - output_shapes = ( - (tf.TensorShape([self.data.n_vars])), - tf.TensorShape([self.ntypes]), - tf.TensorShape([]) - ) + output_types = (output_types_x, tf.float32) + output_shapes = (output_shapes_x, tf.TensorShape([self.ntypes])) return {"output_types": output_types, "output_shapes": output_shapes} def _get_generator( @@ -953,16 +948,20 @@ def _get_generator( onehot_encoder = self._one_hot_encoder() def map_fn(x_sample, obs_sample): - x_sample = np.asarray(x_sample).flatten() + x_sample = np.asarray(x_sample) + output_x = (x_sample, ) if yield_labels: y_sample = onehot_encoder(obs_sample[self._adata_ids.cell_type + self._adata_ids.onto_id_suffix].values) - y_sample = y_sample.flatten() - if y_sample.sum() > 0: - output = x_sample, y_sample, 1. - else: - output = None + # Only yield observations with valid label: + idx_keep = y_sample.sum(axis=1) > 0. + if not np.all(idx_keep): + idx_keep = np.where(idx_keep)[0] + output_x = tuple([x[idx_keep, :] for x in output_x]) + y_sample = y_sample[idx_keep, :] + output_y = y_sample + output = output_x, output_y else: - output = x_sample, + output = output_x, return output g = self.data.generator(idx=idx, retrieval_batch_size=retrieval_batch_size, @@ -1009,7 +1008,7 @@ def ytrue(self, batch_size: int = 128, max_steps: int = np.inf): if len(idx) > 0: dataset = self.get_one_time_tf_dataset(idx=idx, batch_size=batch_size, mode='eval') y_true = [] - for _, y, _ in dataset.as_numpy_iterator(): + for _, y in dataset.as_numpy_iterator(): y_true.append(y) y_true = np.concatenate(y_true, axis=0) return y_true diff --git a/sfaira/models/celltype/marker.py b/sfaira/models/celltype/marker.py index 1b3c342ad..5514f3487 100644 --- a/sfaira/models/celltype/marker.py +++ b/sfaira/models/celltype/marker.py @@ -93,7 +93,7 @@ def __init__( kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )(x) - self.training_model = tf.keras.Model(inputs=inputs, outputs=y, name=name) + self.training_model = tf.keras.Model(inputs=[inputs, ], outputs=y, name=name) class CellTypeMarkerVersioned(CellTypeMarker): diff --git a/sfaira/models/celltype/mlp.py b/sfaira/models/celltype/mlp.py index f846c9131..e69b76063 100644 --- a/sfaira/models/celltype/mlp.py +++ b/sfaira/models/celltype/mlp.py @@ -69,7 +69,7 @@ def __init__( kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )(x) - self.training_model = tf.keras.Model(inputs=inputs, outputs=x, name=name) + self.training_model = tf.keras.Model(inputs=[inputs, ], outputs=x, name=name) class CellTypeMlpVersioned(CellTypeMlp): diff --git a/sfaira/models/embedding/ae.py b/sfaira/models/embedding/ae.py index 428132830..cf75162f0 100644 --- a/sfaira/models/embedding/ae.py +++ b/sfaira/models/embedding/ae.py @@ -194,7 +194,7 @@ def __init__( ) self.training_model = tf.keras.Model( inputs=[inputs_encoder, inputs_sf], - outputs=output_decoder_expfamily_concat, + outputs=[output_decoder_expfamily_concat], name="autoencoder" ) diff --git a/sfaira/models/embedding/linear.py b/sfaira/models/embedding/linear.py index cea092bfe..9d2ec3932 100644 --- a/sfaira/models/embedding/linear.py +++ b/sfaira/models/embedding/linear.py @@ -91,7 +91,7 @@ def __init__( ) self.training_model = tf.keras.Model( inputs=[inputs_encoder, inputs_sf], - outputs=output_decoder_expfamily_concat, + outputs=[output_decoder_expfamily_concat], name="autoencoder" ) diff --git a/sfaira/unit_tests/tests_by_submodule/data/test_store.py b/sfaira/unit_tests/tests_by_submodule/data/test_store.py index 8b1e6a282..e95b410d6 100644 --- 
a/sfaira/unit_tests/tests_by_submodule/data/test_store.py +++ b/sfaira/unit_tests/tests_by_submodule/data/test_store.py @@ -163,38 +163,34 @@ def test_generator_shapes(store_format: str, idx, batch_size: int, obs_keys: Lis Test generators queries do not throw errors and that output shapes are correct. """ # Need to re-write because specific obs_keys are required: - store_path = prepare_store(store_format=store_format, rewrite_store=True) + store_path = prepare_store(store_format=store_format) store = load_store(cache_path=store_path, store_format=store_format) store.subset(attr_key="organism", values=["mouse"]) gc = GenomeContainer(assembly=ASSEMBLY_MOUSE) gc.subset(**{"biotype": "protein_coding"}) store.genome_container = gc - g = store.generator( - idx={"mouse": idx}, - batch_size=batch_size, - obs_keys=obs_keys, - randomized_batch_access=randomized_batch_access, - ) + + def map_fn(x, obs): + return (x, ), + + g = store.generator(idx={"mouse": idx}, batch_size=batch_size, map_fn=map_fn, obs_keys=obs_keys, + randomized_batch_access=randomized_batch_access) g = g.iterator nobs = len(idx) if idx is not None else store.n_obs batch_sizes = [] x = None - obs = None counter = 0 for i, z in enumerate(g()): counter += 1 - x_i, obs_i = z + x_i, = z[0] if len(x_i.shape) == 1: # x is flattened if batch size is 1: assert batch_size == 1 x_i = np.expand_dims(x_i, axis=0) - assert x_i.shape[0] == obs_i.shape[0] if i == 0: x = x_i - obs = obs_i batch_sizes.append(x_i.shape[0]) assert counter > 0 assert x.shape[1] == store.n_vars["mouse"], (x.shape, store.n_vars["mouse"]) - assert obs.shape[1] == len(obs_keys), (obs.shape, obs_keys) assert np.sum(batch_sizes) == nobs, (batch_sizes, nobs) assert x.shape[1] == gc.n_var, (x.shape, gc.n_var) diff --git a/sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py b/sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py index 3dce30567..db04442f0 100644 --- a/sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py +++ b/sfaira/unit_tests/tests_by_submodule/estimators/test_estimator.py @@ -410,10 +410,15 @@ def test_dataset_size(batch_size: int, randomized_batch_access: bool): x_train_shape += x[0].shape[0] # Define raw store generator on train data to compare and check that it has the same size as tf generator exposed # by estimator: + + def map_fn(x, obs): + return (x, ), + g_train = test_estim.estimator.data.generator(idx=idx_train, retrieval_batch_size=retrieval_batch_size, - randomized_batch_access=randomized_batch_access) + randomized_batch_access=randomized_batch_access, map_fn=map_fn) x_train2_shape = 0 - for x, _ in g_train.iterator(): + for x, in g_train.iterator(): + x = x[0] if len(x.shape) == 1: x = np.expand_dims(x, axis=0) x_train2_shape += x.shape[0]
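In summary, after this change map_fn is applied once per retrieval batch rather than once per observation: it receives an expression slice x with a leading batch dimension together with the matching obs frame and returns a tuple of tensor tuples; with batch_size == 1 the generator then splits that tuple into per-observation items via split_batch. A minimal consumption sketch in the spirit of the updated tests (here `store` stands for any store instance such as the one returned by load_store in test_store.py; the index and batch-size values are illustrative):

import numpy as np

def map_fn(x, obs):
    # Vectorised preprocessing: x is a (batch x features) slice, obs the matching
    # metadata frame; return a tuple of input tuples keeping the batch dimension.
    return (np.asarray(x),),

g = store.generator(
    idx={"mouse": np.arange(0, 100)},  # illustrative subset of observations
    retrieval_batch_size=128,          # observations fetched per store access
    map_fn=map_fn,
    obs_keys=[],
)
n_obs = 0
for z in g.iterator():
    x_i, = z[0]                        # first (and only) input tensor of this item
    if len(x_i.shape) == 1:            # batch_size == 1 yields flattened rows
        x_i = np.expand_dims(x_i, axis=0)
    n_obs += x_i.shape[0]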