Skip to content

Commit

Permalink
make push/pull always use source/target (#443)
Browse files Browse the repository at this point in the history
* make push/pull always use source/target

* fix bug in StarPolicy _apply

* adapt plotting to source/target
  • Loading branch information
MUCDK authored Jan 5, 2023
1 parent f30b766 commit 2aa3734
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 96 deletions.
8 changes: 4 additions & 4 deletions src/moscot/plotting/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ def push(
adata=adata,
temporal_key=data["temporal_key"],
key_stored=key,
start=data["start"],
end=data["end"],
source=data["source"],
target=data["target"],
categories=data["subset"],
push=True,
time_points=time_points,
Expand Down Expand Up @@ -310,8 +310,8 @@ def pull(
adata=adata,
temporal_key=data["temporal_key"],
key_stored=key,
start=data["start"],
end=data["end"],
source=data["source"],
target=data["target"],
categories=data["subset"],
push=False,
time_points=time_points,
Expand Down
10 changes: 5 additions & 5 deletions src/moscot/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ def _plot_temporal(
adata: AnnData,
temporal_key: str,
key_stored: str,
start: float,
end: float,
source: float,
target: float,
categories: Optional[Union[str, List[str]]] = None,
*,
push: bool,
Expand Down Expand Up @@ -402,10 +402,10 @@ def _plot_temporal(
else:
name = "descendants" if push else "ancestors"
if time_points is not None:
titles = [f"{categories} at time {start if push else end}"]
titles = [f"{categories} at time {source if push else target}"]
titles.extend([f"{name} at time {time_points[i]}" for i in range(1, len(time_points))])
else:
titles = [f"{categories} at time {start if push else end} and {name}"]
titles = [f"{categories} at time {source if push else target} and {name}"]
for i, ax in enumerate(axs):
with RandomKeys(adata, n=1, where="obs") as keys:
if time_points is None:
Expand All @@ -432,7 +432,7 @@ def _plot_temporal(

_ = kwargs.pop("color_map", None)
_ = kwargs.pop("palette", None)
if (time_points[i] == start and push) or (time_points[i] == end and not push):
if (time_points[i] == source and push) or (time_points[i] == target and not push):
st = f"not in {time_points[i]}"
vmin, vmax = np.nanmin(tmp[mask]), np.nanmax(tmp[mask])
column = pd.Series(tmp).fillna(st).astype("category")
Expand Down
18 changes: 9 additions & 9 deletions src/moscot/problems/base/_compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def _(
data: Optional[Union[str, ArrayLike]] = None,
forward: bool = True,
scale_by_marginals: bool = False,
start: Optional[K] = None,
source: Optional[K] = None,
return_all: bool = True,
**kwargs: Any,
) -> ApplyOutput_t[K]:
Expand All @@ -301,11 +301,11 @@ def _(
res = {}
# TODO(michalk8): should use manager.plan (once implemented), as some problems may not be solved
# TODO: better check
start = start if isinstance(start, list) else [start]
_ = kwargs.pop("end", None) # make compatible with Explicit/Ordered policy
source = source if isinstance(source, list) else [source]
_ = kwargs.pop("target", None) # make compatible with Explicit/Ordered policy
for src, tgt in self._policy.plan(
explicit_steps=kwargs.pop("explicit_steps", None),
filter=start, # type: ignore [arg-type]
filter=source, # type: ignore [arg-type]
):
problem = self.problems[src, tgt]
fun = problem.push if forward else problem.pull
Expand All @@ -319,20 +319,20 @@ def _(
data: Optional[Union[str, ArrayLike]] = None,
forward: bool = True,
scale_by_marginals: bool = False,
start: Optional[K] = None,
end: Optional[K] = None,
source: Optional[K] = None,
target: Optional[K] = None,
return_all: bool = False,
**kwargs: Any,
) -> ApplyOutput_t[K]:
explicit_steps = kwargs.pop(
"explicit_steps", [[start, end]] if isinstance(self._policy, ExplicitPolicy) else None
"explicit_steps", [[source, target]] if isinstance(self._policy, ExplicitPolicy) else None
)
if TYPE_CHECKING:
assert isinstance(self._policy, OrderedPolicy)
(src, tgt), *rest = self._policy.plan(
forward=forward,
start=start,
end=end,
start=source,
end=target,
explicit_steps=explicit_steps,
)
problem = self.problems[src, tgt]
Expand Down
24 changes: 12 additions & 12 deletions src/moscot/problems/base/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class AnalysisMixinProtocol(Protocol[K, B]):
def _apply(
self,
data: Optional[Union[str, ArrayLike]] = None,
start: Optional[K] = None,
end: Optional[K] = None,
source: Optional[K] = None,
target: Optional[K] = None,
forward: bool = True,
return_all: bool = False,
scale_by_marginals: bool = False,
Expand Down Expand Up @@ -304,8 +304,8 @@ def _sample_from_tmap(
mass = np.ones(target_dim)
if account_for_unbalancedness and interpolation_parameter is not None:
col_sums = self._apply(
start=source,
end=target,
source=source,
target=target,
normalize=True,
forward=True,
scale_by_marginals=False,
Expand All @@ -318,8 +318,8 @@ def _sample_from_tmap(

row_probability = np.asarray(
self._apply(
start=source,
end=target,
source=source,
target=target,
data=mass,
normalize=True,
forward=False,
Expand All @@ -339,8 +339,8 @@ def _sample_from_tmap(

col_p_given_row = np.asarray(
self._apply(
start=source,
end=target,
source=source,
target=target,
data=data,
normalize=True,
forward=True,
Expand Down Expand Up @@ -433,8 +433,8 @@ def _annotation_aggregation_transition(
func = self.push if forward else self.pull
for subset in annotations_1:
result = func( # TODO(@MUCDK) check how to make compatible with all policies
start=source,
end=target,
source=source,
target=target,
data=annotation_key,
subset=subset,
normalize=True,
Expand Down Expand Up @@ -472,8 +472,8 @@ def _cell_aggregation_transition(
batch_size = len(df_2)
for batch in range(0, len(df_2), batch_size):
result = func( # TODO(@MUCDK) check how to make compatible with all policies
start=source,
end=target,
source=source,
target=target,
data=None,
subset=(batch, batch_size),
normalize=True,
Expand Down
12 changes: 6 additions & 6 deletions src/moscot/problems/generic/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def push(
Parameters
----------
%(source)s
%(target)s
%(start)s
%(end)s
%(data)s
%(subset)s
%(scale_by_marginals)s
Expand All @@ -122,8 +122,8 @@ def push(
"""
result = self._apply(
start=source,
end=target,
source=source,
target=target,
data=data,
subset=subset,
forward=True,
Expand Down Expand Up @@ -179,8 +179,8 @@ def pull(
"""
result = self._apply(
start=source,
end=target,
source=source,
target=target,
data=data,
subset=subset,
forward=False,
Expand Down
Loading

0 comments on commit 2aa3734

Please sign in to comment.