v0.3.6 #365

Merged: 19 commits, Sep 9, 2021
26 changes: 9 additions & 17 deletions sfaira/data/dataloaders/base/dataset.py
@@ -1120,17 +1120,12 @@ def read_ontology_class_map(self, fn):
if self.cell_type_obs_key is not None:
warnings.warn(f"file {fn} does not exist but cell_type_obs_key {self.cell_type_obs_key} is given")

def project_free_to_ontology(self, attr: str, copy: bool = False):
def project_free_to_ontology(self, attr: str):
"""
Project free text cell type names to ontology based on mapping table.

ToDo: add ontology ID setting here.
ToDo: only for cell type right now, extend to other meta data in the future.

:param copy: If True, a dataframe with the celltype annotation is returned, otherwise self.adata.obs is updated
inplace.

:return:
"""
ontology_map = attr + "_map"
if hasattr(self, ontology_map):
@@ -1139,7 +1134,6 @@ def project_free_to_ontology(self, attr: str, copy: bool = False):
ontology_map = None
print(f"WARNING: did not find ontology map for {attr} which was only defined by free annotation")
adata_fields = self._adata_ids
results = {}
col_original = attr + adata_fields.onto_original_suffix
labels_original = self.adata.obs[col_original].values
if ontology_map is not None: # only if this was defined
@@ -1173,19 +1167,17 @@ def project_free_to_ontology(self, attr: str, copy: bool = False):
# TODO this could be changed in the future, this allows this function to be used both on cell type name
# mapping files with and without the ID in the third column.
# This mapping blocks progression in the unit test if not deactivated.
results[getattr(adata_fields, attr)] = labels_mapped
self.adata.obs[getattr(adata_fields, attr)] = labels_mapped
self.__project_ontology_ids_obs(attr=attr, map_exceptions=map_exceptions, from_id=False,
adata_ids=adata_fields)
else:
results[getattr(adata_fields, attr)] = labels_original
results[getattr(adata_fields, attr) + adata_fields.onto_id_suffix] = \
# Assumes that the original labels are the correct ontology symbols, because of a lack of ontology,
# ontology IDs cannot be inferred.
# TODO is this necessary in the future?
self.adata.obs[getattr(adata_fields, attr)] = labels_original
self.adata.obs[getattr(adata_fields, attr) + adata_fields.onto_id_suffix] = \
[adata_fields.unknown_metadata_identifier] * self.adata.n_obs
results[getattr(adata_fields, attr) + adata_fields.onto_original_suffix] = labels_original
if copy:
return pd.DataFrame(results, index=self.adata.obs.index)
else:
for k, v in results.items():
self.adata.obs[k] = v
self.adata.obs[getattr(adata_fields, attr) + adata_fields.onto_original_suffix] = labels_original

def __impute_ontology_cols_obs(
self,
@@ -1238,7 +1230,7 @@ def __impute_ontology_cols_obs(
# Original annotation (free text):
original_present = col_original in self.adata.obs.columns
if original_present and not symbol_present and not id_present: # 1)
self.project_free_to_ontology(attr=attr, copy=False)
self.project_free_to_ontology(attr=attr)
if symbol_present or id_present: # 2)
if symbol_present and not id_present: # 2a)
self.__project_ontology_ids_obs(attr=attr, from_id=False, adata_ids=adata_ids)
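For orientation, a minimal sketch of the in-place behaviour that project_free_to_ontology now has: free-text labels are read from the original-annotation column, mapped through the mapping table, and written straight back into adata.obs; the copy argument and the returned DataFrame are gone. The column names, mapping dict and UNKNOWN token below are placeholders for illustration, not sfaira API:

import pandas as pd

# Placeholder obs table with a free-text annotation column:
obs = pd.DataFrame({"cell_type_original": ["t cell", "b cell", "t cell"]})
# Placeholder mapping table: free text -> ontology symbol.
ontology_map = {"t cell": "T cell", "b cell": "B cell"}

# In-place projection, analogous to the reworked method: write mapped labels back into obs.
obs["cell_type"] = [ontology_map.get(x, "UNKNOWN") for x in obs["cell_type_original"]]
# No DataFrame copy is returned any more; callers read the result from adata.obs directly.
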
4 changes: 2 additions & 2 deletions sfaira/data/dataloaders/base/dataset_group.py
@@ -512,7 +512,7 @@ def ontology_celltypes(self):
# ToDo: think about whether this should be handled differently.
warnings.warn("found more than one organism in group, this could cause problems with using a joined cell "
"type ontology. Using only the ontology of the first data set in the group.")
return self.datasets[self.ids[0]].ontology_celltypes
return self.datasets[self.ids[0]].ontology_container_sfaira.cell_type

def project_celltypes_to_ontology(self, adata_fields: Union[AdataIds, None] = None, copy=False):
"""
@@ -678,7 +678,7 @@ def __init__(
elif package_source == "sfaira_extension":
package_source = "sfairae"
else:
raise ValueError(f"invalid package source {package_source} for {self._cwd}, {self.collection_id}")
raise ValueError(f"invalid package source {package_source} for {self._cwd}")
except IndexError as e:
raise IndexError(f"{e} for {self._cwd}")
loader_pydoc_path_sfaira = "sfaira.data.dataloaders.loaders."
15 changes: 9 additions & 6 deletions sfaira/data/dataloaders/loaders/super_group.py
@@ -35,12 +35,15 @@ def __init__(
if f[:len(dir_prefix)] == dir_prefix and f not in dir_exclude: # Narrow down to data set directories
path_dsg = str(pydoc.locate(f"sfaira.data.dataloaders.loaders.{f}.FILE_PATH"))
if path_dsg is not None:
dataset_groups.append(DatasetGroupDirectoryOriented(
file_base=path_dsg,
data_path=data_path,
meta_path=meta_path,
cache_path=cache_path
))
try:
dataset_groups.append(DatasetGroupDirectoryOriented(
file_base=path_dsg,
data_path=data_path,
meta_path=meta_path,
cache_path=cache_path
))
except IndexError as e:
raise IndexError(f"{e} for '{cwd}', '{f}', '{path_dsg}'")
else:
warn(f"DatasetGroupDirectoryOriented was None for {f}")
super().__init__(dataset_groups=dataset_groups)
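The error-handling pattern added here, shown in isolation: construction of a data set group is wrapped in try/except so that an IndexError raised deep inside a loader is re-raised with the working directory, loader module name and loader path appended. The helper below is a stand-in, not the sfaira constructor:

def make_group(path_dsg):
    # Stand-in for DatasetGroupDirectoryOriented(...): fails like a broken loader would.
    raise IndexError("list index out of range")

def build_group(cwd, f, path_dsg):
    try:
        return make_group(path_dsg)
    except IndexError as e:
        # Re-raise with context so the failing loader module is identifiable.
        raise IndexError(f"{e} for '{cwd}', '{f}', '{path_dsg}'")
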
18 changes: 10 additions & 8 deletions sfaira/data/store/batch_schedule.py
@@ -24,6 +24,7 @@ class BatchDesignBase:

def __init__(self, retrieval_batch_size: int, randomized_batch_access: bool, random_access: bool, **kwargs):
self.retrieval_batch_size = retrieval_batch_size
self._batch_bounds = None
self._idx = None
if randomized_batch_access and random_access:
raise ValueError("Do not use randomized_batch_access and random_access.")
@@ -51,15 +52,16 @@ def idx(self, x):
self._idx = np.sort(x) # Sorted indices improve accession efficiency in some cases.

@property
def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
def design(self) -> Tuple[np.ndarray, np.ndarray, List[Tuple[int, int]]]:
"""
Yields index objects for one epoch of all data.

These index objects are used by generators that have access to the data objects to build data batches.
Randomization is performed anew with every call to this property.

:returns: Tuple of:
- Ordering of observations in epoch.
- Indices of observations in epoch out of selected data set (ie indices of self.idx).
- Ordering of observations in epoch out of full data set.
- Batch start and end indices for batch based on ordering defined in first output.
"""
raise NotImplementedError()
@@ -68,14 +70,14 @@ def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
class BatchDesignBasic(BatchDesignBase):

@property
def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
idx_proc = self.idx.copy()
def design(self) -> Tuple[np.ndarray, np.ndarray, List[Tuple[int, int]]]:
idx_proc = np.arange(0, len(self.idx))
if self.random_access:
np.random.shuffle(idx_proc)
batch_bounds = self.batch_bounds.copy()
if self.randomized_batch_access:
batch_bounds = _randomize_batch_start_ends(batch_starts_ends=batch_bounds)
return idx_proc, batch_bounds
return idx_proc, self.idx[idx_proc], batch_bounds


class BatchDesignBalanced(BatchDesignBase):
@@ -110,15 +112,15 @@ def __init__(self, grouping, group_weights: dict, randomized_batch_access: bool,
self.p_obs = p_obs

@property
def design(self) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
def design(self) -> Tuple[np.ndarray, np.ndarray, List[Tuple[int, int]]]:
# Re-sample index vector.
idx_proc = np.random.choice(a=self.idx, replace=True, size=len(self.idx), p=self.p_obs)
idx_proc = np.random.choice(a=np.arange(0, len(self.idx)), replace=True, size=len(self.idx), p=self.p_obs)
if not self.random_access: # Note: randomization is result from sampling above, need to revert if not desired.
idx_proc = np.sort(idx_proc)
batch_bounds = self.batch_bounds.copy()
if self.randomized_batch_access:
batch_bounds = _randomize_batch_start_ends(batch_starts_ends=batch_bounds)
return idx_proc, batch_bounds
return idx_proc, self.idx[idx_proc], batch_bounds


BATCH_SCHEDULE = {
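To make the new three-element return value of design concrete, here is how a consumer would unpack it: the first array holds positions within the selected subset (indices into self.idx), the second the corresponding positions in the full data set, and the third the retrieval-batch bounds. The numbers below are invented for illustration only:

import numpy as np

idx = np.array([3, 7, 8, 12, 20])         # selected observations in the full store
idx_local = np.arange(len(idx))           # element 1: positions within idx
idx_global = idx[idx_local]               # element 2: positions in the full data set
batch_bounds = [(0, 2), (2, 5)]           # element 3: (start, end) per retrieval batch

for start, end in batch_bounds:
    rows_in_slice = idx_local[start:end]  # indexes a pre-sliced array such as x[idx, :]
    rows_in_full = idx_global[start:end]  # indexes the full array x
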
108 changes: 66 additions & 42 deletions sfaira/data/store/generators.py
@@ -8,14 +8,24 @@
from sfaira.data.store.batch_schedule import BATCH_SCHEDULE


def split_batch(x, obs):
def split_batch(x):
"""
Splits retrieval batch into consumption batches of length 1.

Often, end-user consumption batches would be observation-wise, ie yield a first dimension of length 1.

:param x: Data tuple of length 1 or 2: (input,) or (input, output,), where both input and output are also
a tuple, but of batch-dimensioned tensors.
"""
for i in range(x.shape[0]):
yield x[i, :], obs.iloc[[i], :]
batch_dim = x[0][0].shape[0]
for i in range(batch_dim):
output = []
for y in x:
if isinstance(y, tuple):
output.append(tuple([z[i, :] for z in y]))
else:
output.append(y[i, :])
yield tuple(output)


class GeneratorBase:
@@ -92,6 +102,7 @@ def __init__(self, batch_schedule, batch_size, map_fn, obs_idx, obs_keys, var_id
in self.adata_by_key.
:param var_idx: The features to emit.
"""
self.var_idx = var_idx
self._obs_idx = None
if not batch_size == 1:
raise ValueError(f"Only batch size==1 is supported, found {batch_size}.")
@@ -103,7 +114,6 @@ def __init__(self, batch_schedule, batch_size, map_fn, obs_idx, obs_keys, var_id
self.schedule = batch_schedule(**kwargs)
self.obs_idx = obs_idx
self.obs_keys = obs_keys
self.var_idx = var_idx

def _validate_idx(self, idx: Union[np.ndarray, list]) -> np.ndarray:
"""
@@ -170,7 +180,7 @@ def iterator(self) -> iter:
# Speed up access to single object by skipping index overlap operations:

def g():
obs_idx, batch_bounds = self.schedule.design
_, obs_idx, batch_bounds = self.schedule.design
for s, e in batch_bounds:
idx_i = obs_idx[s:e]
# Match adata objects that overlap to batch:
@@ -204,13 +214,9 @@ def g():
x = x[:, self.var_idx]
# Prepare .obs.
obs = self.adata_dict[k].obs[self.obs_keys].iloc[v, :]
for x_i, obs_i in split_batch(x=x, obs=obs):
if self.map_fn is None:
yield x_i, obs_i
else:
output = self.map_fn(x_i, obs_i)
if output is not None:
yield output
data_tuple = self.map_fn(x, obs)
for data_tuple_i in split_batch(x=data_tuple):
yield data_tuple_i
else:
# Concatenates slices first before returning. Note that this is likely slower than emitting by
# observation in most scenarios.
@@ -250,64 +256,82 @@ def g():
self.adata_dict[k].obs[self.obs_keys].iloc[v, :]
for k, v in idx_i_dict.items()
], axis=0, join="inner", ignore_index=True, copy=False)
if self.map_fn is None:
yield x, obs
else:
output = self.map_fn(x, obs)
if output is not None:
yield output
data_tuple = self.map_fn(x, obs)
yield data_tuple

return g


class GeneratorDask(GeneratorSingle):

"""
In addition to the full data array, x, this class maintains a slice _x_slice which is indexed by the iterator.
Access to the slice can be optimised with dask and is therefore desirable.
"""

x: dask.array
_x_slice: dask.array
obs: pd.DataFrame

def __init__(self, x, obs, **kwargs):
def __init__(self, x, obs, obs_keys, var_idx, **kwargs):
if var_idx is not None:
x = x[:, var_idx]
self.x = x
super(GeneratorDask, self).__init__(**kwargs)
self.obs = obs[self.obs_keys]
self.obs = obs[obs_keys]
# Redefine index so that .loc indexing can be used instead of .iloc indexing:
self.obs.index = np.arange(0, obs.shape[0])
self._x_slice = None
self._obs_slice = None
super(GeneratorDask, self).__init__(obs_keys=obs_keys, var_idx=var_idx, **kwargs)

@property
def n_obs(self) -> int:
return self.x.shape[0]

@property
def obs_idx(self):
return self._obs_idx

@obs_idx.setter
def obs_idx(self, x):
"""
Allows emission of different iterator on same generator instance (using same dask array).
In addition to base method: allows for optimisation of dask array for batch draws.
"""
if x is None:
x = np.arange(0, self.n_obs)
else:
x = self._validate_idx(x)
x = np.sort(x)
# Only reset if they are actually different:
if (self._obs_idx is not None and len(x) != len(self._obs_idx)) or np.any(x != self._obs_idx):
self._obs_idx = x
self.schedule.idx = x
self._x_slice = dask.optimize(self.x[self._obs_idx, :])[0]
self._obs_slice = self.obs.loc[self.obs.index[self._obs_idx], :] # TODO better than iloc?
# Redefine index so that .loc indexing can be used instead of .iloc indexing:
self._obs_slice.index = np.arange(0, self._obs_slice.shape[0])

@property
def iterator(self) -> iter:
# Can treat all data sets corresponding to one organism as a single array because they share the second dimension
# and dask keeps expression data and obs out of memory.

def g():
obs_idx, batch_bounds = self.schedule.design
x_temp = self.x[obs_idx, :]
obs_temp = self.obs.loc[self.obs.index[obs_idx], :] # TODO better than iloc?
obs_idx_slice, _, batch_bounds = self.schedule.design
x_temp = self._x_slice
obs_temp = self._obs_slice
for s, e in batch_bounds:
x_i = x_temp[s:e, :]
if self.var_idx is not None:
x_i = x_i[:, self.var_idx]
x_i = x_temp[obs_idx_slice[s:e], :]
# Exploit fact that index of obs is just increasing list of integers, so we can use the .loc[]
# indexing instead of .iloc[]:
obs_i = obs_temp.loc[obs_temp.index[s:e], :]
# TODO place map_fn outside of for loop so that vectorisation in preprocessing can be used.
obs_i = obs_temp.loc[obs_temp.index[obs_idx_slice[s:e]], :]
data_tuple = self.map_fn(x_i, obs_i)
if self.batch_size == 1:
for x_ii, obs_ii in split_batch(x=x_i, obs=obs_i):
if self.map_fn is None:
yield x_ii, obs_ii
else:
output = self.map_fn(x_ii, obs_ii)
if output is not None:
yield output
for data_tuple_i in split_batch(x=data_tuple):
yield data_tuple_i
else:
if self.map_fn is None:
yield x_i, obs_i
else:
output = self.map_fn(x_i, obs_i)
if output is not None:
yield output
yield data_tuple

return g

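To make the reworked map_fn contract explicit: map_fn(x, obs) is now applied once per retrieval batch (the map_fn-is-None branches were removed) and returns a tuple whose elements are tuples of batch-dimensioned tensors (or bare batch-dimensioned tensors), e.g. ((x,),) or ((x,), (y,)); split_batch then slices every inner tensor observation-wise. A self-contained sketch, where the map_fn below is a placeholder rather than part of the PR:

import numpy as np
import pandas as pd

def map_fn(x, obs):
    # Placeholder map_fn: emit the expression matrix as the single input tensor.
    return (np.asarray(x),),

x = np.arange(12).reshape(4, 3)
obs = pd.DataFrame({"cell_type": list("abcd")})

data_tuple = map_fn(x, obs)            # ((array of shape (4, 3),),)
batch_dim = data_tuple[0][0].shape[0]  # 4; split_batch would yield 4 one-observation tuples
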
9 changes: 7 additions & 2 deletions sfaira/data/store/single_store.py
@@ -583,15 +583,20 @@ def X_slice(self, idx: np.ndarray, as_sparse: bool = True, **kwargs) -> Union[np
:return: Slice of data array.
"""
batch_size = min(len(idx), 128)

def map_fn(x, obs):
return (x, ),

g = self.generator(idx=idx, retrieval_batch_size=batch_size, return_dense=True, random_access=False,
randomized_batch_access=False, **kwargs)
randomized_batch_access=False, map_fn=map_fn, **kwargs)
shape = (idx.shape[0], self.n_vars)
if as_sparse:
x = scipy.sparse.csr_matrix(np.zeros(shape))
else:
x = np.empty(shape)
counter = 0
for x_batch, _ in g.iterator():
for x_batch, in g.iterator():
x_batch = x_batch[0]
batch_len = x_batch.shape[0]
x[counter:(counter + batch_len), :] = x_batch
counter += batch_len
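The unpacking pattern that X_slice now uses, shown against a stand-in iterator: because the internal map_fn returns ((x,),), every item yielded by g.iterator() is a 1-tuple whose single element is a 1-tuple holding the dense block. The iterator below is a placeholder for g.iterator(), not sfaira API:

import numpy as np

def fake_iterator():
    # Stand-in for g.iterator(): yields ((x,),) per retrieval batch.
    for chunk in np.array_split(np.arange(20).reshape(10, 2), 3):
        yield (chunk,),

rows = []
for x_batch, in fake_iterator():
    rows.append(x_batch[0])  # x_batch is (x,); x_batch[0] is the dense block
x = np.concatenate(rows, axis=0)  # (10, 2), i.e. all retrieval batches stacked back together
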
26 changes: 26 additions & 0 deletions sfaira/data/utils_scripts/survey_obs_annotation.py
@@ -0,0 +1,26 @@
import numpy as np
import sfaira
import sys

# Set global variables.
print("sys.argv", sys.argv)

data_path = str(sys.argv[1])
path_meta = str(sys.argv[2])
path_cache = str(sys.argv[3])

universe = sfaira.data.dataloaders.Universe(
data_path=data_path, meta_path=path_meta, cache_path=path_cache
)
for k, v in universe.datasets.items():
print(k)
v.load(
load_raw=False,
allow_caching=True,
)
for col in v.adata.obs.columns:
val = np.sort(np.unique(v.adata.obs[col].values))
if len(val) > 20:
val = val[:20]
print(f"{k}: {col}: {val}")
v.clear()
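Usage note, inferred from the sys.argv parsing above and hedged accordingly: the script expects three positional arguments, so an invocation would look roughly like

python sfaira/data/utils_scripts/survey_obs_annotation.py <data_path> <meta_path> <cache_path>

with placeholders for the raw data directory, the meta data directory and the cache directory handed to Universe; for each data set it then prints up to 20 unique values per obs column.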