Skip to content

Commit

Permalink
Fix UI (#378)
Browse files Browse the repository at this point in the history
* add unit tests for UI

* fix ui

* store sfaira_repo_url in sfaira.consts

* remove capitalisation of gene symbols when streamlining features

* extend unit test to check that gene symbols match the genome container after feature streamlining

* added check that genome assemblies are the same between celltype and embedding zoo

* making sure var_names are automatically set according to schema in ui

* keep original obs_names when loading a dataset through the ui
  • Loading branch information
le-ander authored Sep 28, 2021
1 parent c26c0db commit 6d0eb83
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 19 deletions.
2 changes: 1 addition & 1 deletion sfaira/consts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sfaira.consts.adata_fields import AdataIds, AdataIdsSfaira, AdataIdsCellxgene, AdataIdsCellxgeneGeneral, \
AdataIdsCellxgeneHuman_v1_1_0, AdataIdsCellxgeneMouse_v1_1_0
from sfaira.consts.directories import CACHE_DIR
from sfaira.consts.directories import CACHE_DIR, SFAIRA_REPO_URL
from sfaira.consts.meta_data_files import META_DATA_FIELDS
from sfaira.consts.ontologies import OntologyContainerSfaira
from sfaira.consts.utils import clean_cache
Expand Down
2 changes: 2 additions & 0 deletions sfaira/consts/directories.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@
CACHE_DIR_GENOMES = os.path.join(CACHE_DIR, "genomes")

CACHE_DIR_ONTOLOGIES = os.path.join(CACHE_DIR, "ontologies")

SFAIRA_REPO_URL = "https://zenodo.org/record/4836517/files/"
2 changes: 1 addition & 1 deletion sfaira/data/dataloaders/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def streamline_features(
if y in subset_genes_to_type
]
subset_ids_symbol = [
x.upper() for x, y in zip(self.genome_container.symbols, self.genome_container.biotype)
x for x, y in zip(self.genome_container.symbols, self.genome_container.biotype)
if y in subset_genes_to_type
]

Expand Down
32 changes: 29 additions & 3 deletions sfaira/ui/user_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
import time

from sfaira.consts import AdataIdsSfaira, AdataIds, OCS
from sfaira.consts import AdataIdsSfaira, AdataIds, OCS, SFAIRA_REPO_URL
from sfaira.data import DatasetInteractive
from sfaira.estimators import EstimatorKerasEmbedding, EstimatorKerasCelltype
from sfaira.ui.model_zoo import ModelZoo
Expand Down Expand Up @@ -61,14 +61,17 @@ def __init__(
self.adata_ids = AdataIdsSfaira()

if sfaira_repo: # check if public sfaira repository should be accessed
self.model_lookuptable = self._load_lookuptable("https://zenodo.org/record/4836517/files/")
self.model_lookuptable = self._load_lookuptable(SFAIRA_REPO_URL)

if custom_repo:
if isinstance(custom_repo, str):
custom_repo = [custom_repo]

for repo in custom_repo:
if os.path.exists(repo) and not os.path.exists(os.path.join(repo, 'model_lookuptable.csv')):
if not os.path.exists(repo):
raise OSError(f"provided repo directory does not exist, please create it first: {repo}")

if not os.path.exists(os.path.join(repo, 'model_lookuptable.csv')):
self.write_lookuptable(repo)

if hasattr(self, 'model_lookuptable'):
Expand Down Expand Up @@ -391,6 +394,13 @@ def load_data(
match_to_reference=self.zoo_embedding.topology_container.gc.assembly,
subset_genes_to_type=list(set(self.zoo_embedding.topology_container.gc.biotype))
)
# Transfer required metadata from the Dataset instance to the adata object
self.data.streamline_metadata(
clean_obs=False,
clean_var=True,
clean_uns=False,
clean_obs_names=False,
)

def _load_topology_dict(self, model_weights_file) -> dict:
topology_filepath = ".".join(model_weights_file.split(".")[:-1])
Expand Down Expand Up @@ -423,6 +433,14 @@ def load_model_embedding(self):
:return: Model ID loaded.
"""
assert self.zoo_embedding.model_id is not None, "choose embedding model first"
if self.zoo_celltype.topology_container.gc.assembly is not None:
assert self.zoo_embedding.topology_container.gc.assembly == \
self.zoo_celltype.topology_container.gc.assembly, f"genome assemblies defined in the topology " \
f"containers if the embedding and the celltype " \
f"prediction model are not equivalent " \
f"({self.zoo_embedding.topology_container.gc.assembly} " \
f"and {self.zoo_celltype.topology_container.gc.assembly} " \
f"respectively, aborting.)"
model_weights_file = self.model_lookuptable["model_file_path"].loc[self.model_lookuptable["model_id"] ==
self.zoo_embedding.model_id].iloc[0]
md5 = self.model_lookuptable["md5"].loc[self.model_lookuptable["model_id"] ==
Expand Down Expand Up @@ -452,6 +470,14 @@ def load_model_celltype(self):
:return: Model ID loaded.
"""
assert self.zoo_celltype.model_id is not None, "choose cell type model first"
if self.zoo_embedding.topology_container.gc.assembly is not None:
assert self.zoo_embedding.topology_container.gc.assembly == \
self.zoo_celltype.topology_container.gc.assembly, f"genome assemblies defined in the topology " \
f"containers if the embedding and the celltype " \
f"prediction model are not equivalent " \
f"({self.zoo_embedding.topology_container.gc.assembly} " \
f"and {self.zoo_celltype.topology_container.gc.assembly} " \
f"respectively, aborting.)"
model_weights_file = self.model_lookuptable["model_file_path"].loc[self.model_lookuptable["model_id"] ==
self.zoo_celltype.model_id].iloc[0]
md5 = self.model_lookuptable["md5"].loc[self.model_lookuptable["model_id"] ==
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def test_dsgs_streamline_features(match_to_reference: str, remove_gene_version:
ds.load()
ds.streamline_features(remove_gene_version=remove_gene_version, match_to_reference=match_to_reference,
subset_genes_to_type=subset_genes_to_type)
gc = ds.get_gc(match_to_reference["mouse"] if isinstance(match_to_reference, dict) else match_to_reference)
gc.subset(biotype=subset_genes_to_type)
assert ds.adata.var["gene_symbol"].tolist() == gc.symbols


def test_dsg_load():
Expand Down
64 changes: 50 additions & 14 deletions sfaira/unit_tests/tests_by_submodule/ui/test_userinterface.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import numpy as np
import os
from typing import Union
import pandas as pd
import urllib.request

from sfaira.ui import UserInterface
from sfaira.unit_tests.data_for_tests.loaders.utils import prepare_dsg
from sfaira.unit_tests import DIR_TEMP
from sfaira.consts import SFAIRA_REPO_URL


class HelperUi:
Expand All @@ -17,30 +21,62 @@ class HelperUi:
basic_estimator_test(). See _test_call() for an example.
"""

def simulate(self):
"""
Simulate basic data example used for unit test.
def __init__(self):
self.temp_fn = os.path.join(DIR_TEMP, "test_data")

def prepare_local_tempfiles(self):
    """
    Prepare a local model repository in ``self.temp_fn`` for the UI tests.

    Creates the directory if it does not exist yet and downloads the weights file referenced
    in the ``model_file_path`` column at index label ``0`` of the public sfaira lookup table,
    unless a file with the same base name is already present in the directory.

    Needs network access to the repository behind ``SFAIRA_REPO_URL``.
    """
    # exist_ok=True replaces the separate existence check and avoids the check-then-create race.
    os.makedirs(self.temp_fn, exist_ok=True)
    # SFAIRA_REPO_URL is a URL, not a file system path, so it is joined with "/" explicitly
    # instead of relying on the platform-specific os.path.join.
    lookuptable_url = SFAIRA_REPO_URL.rstrip("/") + "/model_lookuptable.csv"
    lookuptable = pd.read_csv(lookuptable_url, header=0, index_col=0)
    url = lookuptable.loc[0, "model_file_path"]
    # Download an example weights file only if it has not been fetched by an earlier run.
    target_fn = os.path.join(self.temp_fn, os.path.basename(url))
    if not os.path.exists(target_fn):
        urllib.request.urlretrieve(url, target_fn)

Sets attribute .data with simulated data.
def _get_adata(self):
"""
Create an adata object for use in unit tests
:return:
"""
pass
dsg = prepare_dsg(rewrite=True, load=False)
dsg.subset(key="id", values=["human_lung_2021_None_mock4_001_no_doi_mock4"])
dsg.load()
return dsg.adata

def test_local_repo_ui_init(self):
"""
Test if the sfaira UI can be successfully initialised using a local model repository
def test_basic(self):
:return:
"""
Test all relevant model methods.
self.ui = UserInterface(custom_repo=self.temp_fn, sfaira_repo=False)

def test_public_repo_ui_init(self):
"""
Test if the sfaira UI can be successfully initialised using the public sfaira model repository
:return:
"""
temp_fn = os.path.join(DIR_TEMP, "test_data")
self.ui = UserInterface(custom_repo=temp_fn, sfaira_repo=False)
self.ui = UserInterface(custom_repo=None, sfaira_repo=True)

def test_data_and_model_loading(self):
    """
    Run the UI workflow end to end against the public sfaira model repository.

    Selects one embedding and one cell type model by ID, loads the adata object returned by
    ``_get_adata``, loads both models, runs ``predict_all`` and asserts that the keys
    "X_sfaira" (obsm) and "celltypes_sfaira" (obs) are present on the resulting adata object.
    """
    ui = UserInterface(custom_repo=None, sfaira_repo=True)
    self.ui = ui
    # Model IDs have to be set on both zoos before data and models are loaded.
    ui.zoo_embedding.model_id = 'embedding_human-blood-ae-0.2-0.1_theislab'
    ui.zoo_celltype.model_id = 'celltype_human-blood-mlp-0.1.3-0.1_theislab'
    adata = self._get_adata()
    ui.load_data(adata, gene_ens_col='index')
    ui.load_model_celltype()
    ui.load_model_embedding()
    ui.predict_all()
    # Check that the prediction outputs were written to the adata object.
    assert "X_sfaira" in self.ui.data.adata.obsm_keys()
    assert "celltypes_sfaira" in self.ui.data.adata.obs_keys()

def _test_for_fatal():
"""
TODO need to simulate/add look up table as part of unit tests locally
"""

def test_ui():
ui = HelperUi()
ui.test_basic()
ui.prepare_local_tempfiles()
ui.test_public_repo_ui_init()
ui.test_local_repo_ui_init()
ui.test_data_and_model_loading()

0 comments on commit 6d0eb83

Please sign in to comment.