Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed store to estimator interface and added further unit tests on st… #255

Merged
merged 16 commits into from
Apr 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions sfaira/data/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,8 @@ def streamline_metadata(
:param schema: Export format.
- "sfaira"
- "cellxgene"
:param uns_to_obs: Whether to move metadata in .uns to .obs to make sure it's not lost when concatenating multiple objects.
:param uns_to_obs: Whether to move metadata in .uns to .obs to make sure it's not lost when concatenating
multiple objects. Retains .id in .uns.
:param clean_obs: Whether to delete non-streamlined fields in .obs, .obsm and .obsp.
:param clean_var: Whether to delete non-streamlined fields in .var, .varm and .varp.
:param clean_uns: Whether to delete non-streamlined fields in .uns.
Expand Down Expand Up @@ -909,7 +910,9 @@ def streamline_metadata(
for k, v in self.adata.uns.items():
if k not in self.adata.obs_keys():
self.adata.obs[k] = [v for i in range(self.adata.n_obs)]
self.adata.uns = {}
# Retain only target uns keys in .uns.
self.adata.uns = dict([(k, v) for k, v in self.adata.uns.items()
if k in [getattr(adata_target_ids, kk) for kk in ["id"]]])

self._adata_ids = adata_target_ids # set new adata fields to class after conversion
self.streamlined_meta = True
Expand Down Expand Up @@ -948,6 +951,7 @@ def write_distributed_store(
f"data, found {type(self.adata.X)}")
fn = os.path.join(dir_cache, self.doi_cleaned_id + ".h5ad")
as_dense = ("X",) if dense else ()
print(f"writing {self.adata.shape} into {fn}")
self.adata.write_h5ad(filename=fn, as_dense=as_dense, **compression_kwargs)
elif store == "zarr":
fn = os.path.join(dir_cache, self.doi_cleaned_id)
Expand Down
252 changes: 153 additions & 99 deletions sfaira/data/base/distributed_store.py

Large diffs are not rendered by default.

304 changes: 176 additions & 128 deletions sfaira/estimators/keras.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion sfaira/interface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from sfaira.interface.model_zoo import ModelZoo, ModelZooEmbedding, ModelZooCelltype
from sfaira.interface.model_zoo import ModelZoo
from sfaira.interface.user_interface import UserInterface
306 changes: 92 additions & 214 deletions sfaira/interface/model_zoo.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import abc
try:
import kipoi
except ImportError:
kipoi = None
import numpy as np
import pandas as pd
from typing import List, Union

from sfaira.versions.metadata import CelltypeUniverse
from sfaira.consts import OntologyContainerSfaira
from sfaira.versions.topologies import TopologyContainer
from sfaira.versions.topologies import TopologyContainer, TOPOLOGIES


class ModelZoo(abc.ABC):
Expand All @@ -18,39 +14,95 @@ class ModelZoo(abc.ABC):
"""
topology_container: TopologyContainer
ontology: dict
model_id: Union[str, None]
model_class: Union[str, None]
model_class: Union[str, None]
model_type: Union[str, None]
model_topology: Union[str, None]
model_version: Union[str, None]
_model_id: Union[str, None]
celltypes: Union[CelltypeUniverse, None]

def __init__(
self,
model_lookuptable: Union[None, pd.DataFrame] = None
model_lookuptable: Union[None, pd.DataFrame] = None,
model_class: Union[str, None] = None,
):
"""
:param model_lookuptable: model_lookuptable.
:param model_class: Model class to subset to.
"""
self._ontology_container_sfaira = OntologyContainerSfaira()
if model_lookuptable is not None: # check if models in repository
self.ontology = self.load_ontology_from_model_ids(model_lookuptable['model_id'].values)
self.model_id = None
self.model_class = None
self.model_type = None
self.organisation = None
self.model_topology = None
self.model_version = None
self.topology_container = None
self.ontology = self.load_ontology_from_model_ids(model_ids=model_lookuptable['model_id'].values,
model_class=model_class)
self._model_id = None
self.celltypes = None

@abc.abstractmethod
@property
def model_class(self):
# Model class of the currently selected model: the first "_"-separated field of model_id.
# Fails with an AssertionError if no model_id has been set yet.
assert self.model_id is not None, "set model_id first"
return self.model_id.split('_')[0]

@property
def model_name(self):
# Model name of the currently selected model: the second "_"-separated field of model_id.
# Fails with an AssertionError if no model_id has been set yet.
assert self.model_id is not None, "set model_id first"
return self.model_id.split('_')[1]

@property
def model_organism(self):
# Organism of the currently selected model: the first "-"-separated part of the second
# "_"-separated field of model_id. Fails with an AssertionError if no model_id has been set yet.
# TODO: this is a custom name ontology
assert self.model_id is not None, "set model_id first"
return self.model_id.split('_')[1].split("-")[0]

@property
def model_organ(self):
# Organ of the currently selected model: the second "-"-separated part of the second
# "_"-separated field of model_id. Fails with an AssertionError if no model_id has been set yet.
# TODO: this is a custom name ontology
assert self.model_id is not None, "set model_id first"
return self.model_id.split('_')[1].split("-")[1]

@property
def model_type(self):
# Model type of the currently selected model: the third "-"-separated part of the second
# "_"-separated field of model_id. Fails with an AssertionError if no model_id has been set yet.
# TODO: this is a custom name ontology
assert self.model_id is not None, "set model_id first"
return self.model_id.split('_')[1].split("-")[2]

@property
def model_topology(self):
# Model topology of the currently selected model: the fourth "-"-separated part of the second
# "_"-separated field of model_id. Fails with an AssertionError if no model_id has been set yet.
# TODO: this is a custom name ontology
assert self.model_id is not None, "set model_id first"
return self.model_id.split('_')[1].split("-")[3]

@property
def model_version(self):
# Model version of the currently selected model: the fifth "-"-separated part of the second
# "_"-separated field of model_id. Fails with an AssertionError if no model_id has been set yet.
# TODO: this is a custom name ontology
assert self.model_id is not None, "set model_id first"
return self.model_id.split('_')[1].split("-")[4]

@property
def organisation(self):
# Organisation of the currently selected model: the third "_"-separated field of model_id.
# Fails with an AssertionError if no model_id has been set yet.
assert self.model_id is not None, "set model_id first"
return self.model_id.split('_')[2]

def load_ontology_from_model_ids(
        self,
        model_ids,
        model_class: Union[str, None] = None,
) -> dict:
    """
    Load model ontology based on models available in model lookup tables.

    Model IDs are expected to consist of three "_"-separated fields: model class, model name and
    organisation. The returned ontology is nested as model name -> organisation -> None.

    :param model_ids: Table listing all available model_ids.
    :param model_class: Model class to subset to. All model classes are kept if None.
    :return: Dictionary formatted ontology.
    """
    ids = [x for x in model_ids if model_class is None or x.split('_')[0] == model_class]
    # Everything after the model class has to match the two columns declared here, pandas raises on
    # IDs with a different number of fields instead of silently mis-assigning them.
    id_df = pd.DataFrame(
        [x.split('_')[1:] for x in ids],
        columns=['name', 'organisation']
    )
    ontology = {}
    for name in np.unique(id_df['name']):
        # Select by the 'name' column that was actually created above.
        organisations = np.unique(id_df.loc[id_df['name'] == name, 'organisation'])
        ontology[name] = dict.fromkeys(organisations)
    return ontology

def _order_versions(
self,
Expand All @@ -66,25 +118,19 @@ def _order_versions(

return versions

def set_model_id(
self,
model_id: str
):
@property
def model_id(self):
return self._model_id

@model_id.setter
def model_id(self, x: str):
"""
Set model ID to a manually supplied ID.

:param model_id: Model ID to set. Format: pipeline_genome_organ_model_organisation_topology_version
:param x: Model ID to set. Format: modelclass_organism-organ-modeltype-topology-version_organisation
"""
if len(model_id.split('_')) < 6:
raise RuntimeError(f'Model ID {model_id} is invalid!')
self.model_id = model_id
ixs = self.model_id.split('_')
self.model_class = ixs[0]
self.model_id = ixs[1]
self.model_type = ixs[2]
self.organisation = ixs[3]
self.model_topology = ixs[4]
self.model_version = ixs[5]
assert len(x.split('_')) == 3, f'model_id {x} is invalid'
self._model_id = x

def save_weights_to_remote(self, path=None):
"""
Expand Down Expand Up @@ -113,14 +159,6 @@ def call_kipoi(self):
"""
raise NotImplementedError()

def models(self) -> List[str]:
    """
    Return list of available models.

    The top-level keys of ``self.ontology`` are copied into a new list, so the result matches the
    annotated return type and is not a live view on the ontology dictionary.

    :return: List of models available.
    """
    return list(self.ontology)

def topology(
self,
model_type: str,
Expand Down Expand Up @@ -164,171 +202,11 @@ def model_hyperparameters(self) -> dict:
assert self.topology_container is not None
return self.topology_container.topology["hyper_parameters"]


class ModelZooEmbedding(ModelZoo):

"""
The supported model ontology is:

organism -> organ -> model -> organisation -> topology -> version -> ID

Maybe: include experimental protocol? Ie droplet, full-length, single-nuclei.
"""

def load_ontology_from_model_ids(
self,
model_ids
) -> dict:
"""
Load model ontology based on models available in model lookup tables.

:param model_ids: Table listing all available model_ids.
:return: Dictionary formatted ontology.
"""

# Keep only IDs whose first "_"-separated field is 'embedding'.
ids = [i for i in model_ids if i.split('_')[0] == 'embedding']
# Fields 1 to 5 of each ID become the five columns named below.
id_df = pd.DataFrame(
[i.split('_')[1:6] for i in ids],
columns=['id', 'model_type', 'organisation', 'model_topology', 'model_version']
)
# Build the nested dictionary model_type -> organisation -> model_topology -> list of model versions.
model = np.unique(id_df['model_type'])
ontology = dict.fromkeys(model)
for m in model:
id_df_m = id_df[id_df.model_type == m]
orga = np.unique(id_df_m['organisation'])
ontology[m] = dict.fromkeys(orga)
for org in orga:
id_df_org = id_df_m[id_df_m.organisation == org]
topo = np.unique(id_df_org['model_topology'])
ontology[m][org] = dict.fromkeys(topo)
for t in topo:
id_df_t = id_df_org[id_df_org.model_topology == t]
# Leaf level: all versions listed for this model type, organisation and topology.
ontology[m][org][t] = id_df_t.model_version.tolist()

return ontology

def set_latest(
self,
model_type: str,
organisation: str,
model_topology: str
):
"""
Set model ID to latest model in given ontology group.

:param model_type: Identifier of model_type to select.
:param organisation: Identifier of organisation to select.
:param model_topology: Identifier of model_topology to select
:return:
"""
# The requested group has to exist at every level of the nested ontology dictionary.
assert model_type in self.ontology.keys(), "model_type requested was not found in ontology"
assert organisation in self.ontology[model_type].keys(), \
"organisation requested was not found in ontology"
assert model_topology in self.ontology[model_type][organisation].keys(), \
"model_topology requested was not found in ontology"

# NOTE(review): self.versions is not defined in the visible part of this file - confirm it exists.
versions = self.versions(
model_type=model_type,
organisation=organisation,
model_topology=model_topology
)
self.model_type = model_type
self.organisation = organisation
self.model_topology = model_topology  # set to model for now, could be organism/organ specific later

# The first element returned by _order_versions is used as the version to select;
# presumably that is the latest one - verify against _order_versions.
self.model_version = self._order_versions(versions=versions)[0]
# NOTE(review): self.id is not assigned anywhere in the visible code - confirm where it is set.
self.model_id = '_'.join([
'embedding',
self.id,
self.model_type,
self.organisation,
self.model_topology,
self.model_version
])


class ModelZooCelltype(ModelZoo):
"""
The supported model ontology is:

organism -> organ -> model -> organisation -> topology -> version -> ID

Maybe: include experimental protocol? Ie droplet, full-length, single-nuclei.

Note on topology id: The topology ID is x.y.z, x is the major cell type version and y.z is the cell type model
topology. Cell type model ontologies do not include the output size as this is set by the cell type version.
"""

def load_ontology_from_model_ids(
self,
model_ids
) -> dict:
"""
Load model ontology based on models available in model lookup tables.

:param model_ids: Table listing all available model_ids.
:return: Dictionary formatted ontology.
"""

# Keep only IDs whose first "_"-separated field is 'celltype'.
ids = [i for i in model_ids if i.split('_')[0] == 'celltype']
# Fields 1 to 5 of each ID become the five columns named below.
id_df = pd.DataFrame(
[i.split('_')[1:6] for i in ids],
columns=['id', 'model_type', 'organisation', 'model_topology', 'model_version']
)
# Build the nested dictionary model_type -> organisation -> model_topology -> list of model versions.
model = np.unique(id_df['model_type'])
ontology = dict.fromkeys(model)
for m in model:
id_df_m = id_df[id_df.model_type == m]
orga = np.unique(id_df_m['organisation'])
ontology[m] = dict.fromkeys(orga)
for org in orga:
id_df_org = id_df_m[id_df_m.organisation == org]
topo = np.unique(id_df_org['model_topology'])
ontology[m][org] = dict.fromkeys(topo)
for t in topo:
id_df_t = id_df_org[id_df_org.model_topology == t]
# Leaf level: all versions listed for this model type, organisation and topology.
ontology[m][org][t] = id_df_t.model_version.tolist()

return ontology

def set_latest(
self,
model_type: str,
organisation: str,
model_topology: str
):
"""
Set model ID to latest model in given ontology group.

:param organism: Identifier of organism to select.
:param organ: Identifier of organ to select.
:param model_type: Identifier of model_type to select.
:param organisation: Identifier of organisation to select.
:param model_topology: Identifier of model_topology to select
:return:
"""
assert model_type in self.ontology.keys(), "model_type requested was not found in ontology"
assert organisation in self.ontology[model_type].keys(), \
"organisation requested was not found in ontology"
assert model_topology in self.ontology[model_type][organisation].keys(), \
"model_topology requested was not found in ontology"

versions = self.versions(
model_type=model_type,
organisation=organisation,
model_topology=model_topology
@property
def topology_container(self) -> TopologyContainer:
# Builds a new TopologyContainer for the currently selected model on every access.
# The topology is looked up in TOPOLOGIES by organism, model class, model type and model topology;
# the model version is passed on as topology_id.
# TODO: this ID decomposition to organism is custom to the topologies handled in this package.
organism = self.model_name.split("-")[0]
return TopologyContainer(
topology=TOPOLOGIES[organism][self.model_class][self.model_type][self.model_topology],
topology_id=self.model_version
)

self.model_type = model_type
self.organisation = organisation
self.model_topology = model_topology # set to model for now, could be organism/organ specific later

self.model_version = self._order_versions(versions=versions)[0]
self.model_id = '_'.join([
'celltype',
self.id,
self.model_type,
self.organisation,
self.model_topology,
self.model_version
])
Loading