diff --git a/src/moscot/plotting/_plotting.py b/src/moscot/plotting/_plotting.py
index 1eb36b2d7..e34b67f9e 100644
--- a/src/moscot/plotting/_plotting.py
+++ b/src/moscot/plotting/_plotting.py
@@ -225,8 +225,8 @@ def push(
         adata=adata,
         temporal_key=data["temporal_key"],
         key_stored=key,
-        start=data["start"],
-        end=data["end"],
+        source=data["source"],
+        target=data["target"],
         categories=data["subset"],
         push=True,
         time_points=time_points,
@@ -310,8 +310,8 @@ def pull(
         adata=adata,
         temporal_key=data["temporal_key"],
         key_stored=key,
-        start=data["start"],
-        end=data["end"],
+        source=data["source"],
+        target=data["target"],
         categories=data["subset"],
         push=False,
         time_points=time_points,
diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py
index ed6eb917e..950a5d781 100644
--- a/src/moscot/plotting/_utils.py
+++ b/src/moscot/plotting/_utils.py
@@ -359,8 +359,8 @@ def _plot_temporal(
     adata: AnnData,
     temporal_key: str,
     key_stored: str,
-    start: float,
-    end: float,
+    source: float,
+    target: float,
     categories: Optional[Union[str, List[str]]] = None,
     *,
     push: bool,
@@ -402,10 +402,10 @@ def _plot_temporal(
     else:
         name = "descendants" if push else "ancestors"
         if time_points is not None:
-            titles = [f"{categories} at time {start if push else end}"]
+            titles = [f"{categories} at time {source if push else target}"]
             titles.extend([f"{name} at time {time_points[i]}" for i in range(1, len(time_points))])
         else:
-            titles = [f"{categories} at time {start if push else end} and {name}"]
+            titles = [f"{categories} at time {source if push else target} and {name}"]
     for i, ax in enumerate(axs):
         with RandomKeys(adata, n=1, where="obs") as keys:
             if time_points is None:
@@ -432,7 +432,7 @@ def _plot_temporal(
                 _ = kwargs.pop("color_map", None)
                 _ = kwargs.pop("palette", None)

-                if (time_points[i] == start and push) or (time_points[i] == end and not push):
+                if (time_points[i] == source and push) or (time_points[i] == target and not push):
                     st = f"not in {time_points[i]}"
                     vmin, vmax = np.nanmin(tmp[mask]), np.nanmax(tmp[mask])
                     column = pd.Series(tmp).fillna(st).astype("category")
diff --git a/src/moscot/problems/base/_compound_problem.py b/src/moscot/problems/base/_compound_problem.py
index 8b8ca61ca..5b11a53e6 100644
--- a/src/moscot/problems/base/_compound_problem.py
+++ b/src/moscot/problems/base/_compound_problem.py
@@ -292,7 +292,7 @@ def _(
         data: Optional[Union[str, ArrayLike]] = None,
         forward: bool = True,
         scale_by_marginals: bool = False,
-        start: Optional[K] = None,
+        source: Optional[K] = None,
         return_all: bool = True,
         **kwargs: Any,
     ) -> ApplyOutput_t[K]:
@@ -301,11 +301,11 @@ def _(
         res = {}
         # TODO(michalk8): should use manager.plan (once implemented), as some problems may not be solved
         # TODO: better check
-        start = start if isinstance(start, list) else [start]
-        _ = kwargs.pop("end", None)  # make compatible with Explicit/Ordered policy
+        source = source if isinstance(source, list) else [source]
+        _ = kwargs.pop("target", None)  # make compatible with Explicit/Ordered policy
         for src, tgt in self._policy.plan(
             explicit_steps=kwargs.pop("explicit_steps", None),
-            filter=start,  # type: ignore [arg-type]
+            filter=source,  # type: ignore [arg-type]
         ):
             problem = self.problems[src, tgt]
             fun = problem.push if forward else problem.pull
@@ -319,20 +319,20 @@ def _(
         data: Optional[Union[str, ArrayLike]] = None,
         forward: bool = True,
         scale_by_marginals: bool = False,
-        start: Optional[K] = None,
-        end: Optional[K] = None,
+        source: Optional[K] = None,
+        target: Optional[K] = None,
         return_all: bool = False,
         **kwargs: Any,
     ) -> ApplyOutput_t[K]:
         explicit_steps = kwargs.pop(
-            "explicit_steps", [[start, end]] if isinstance(self._policy, ExplicitPolicy) else None
+            "explicit_steps", [[source, target]] if isinstance(self._policy, ExplicitPolicy) else None
         )
         if TYPE_CHECKING:
             assert isinstance(self._policy, OrderedPolicy)
         (src, tgt), *rest = self._policy.plan(
             forward=forward,
-            start=start,
-            end=end,
+            start=source,
+            end=target,
             explicit_steps=explicit_steps,
         )
         problem = self.problems[src, tgt]
diff --git a/src/moscot/problems/base/_mixins.py b/src/moscot/problems/base/_mixins.py
index 147afbae7..8b5b3b7d8 100644
--- a/src/moscot/problems/base/_mixins.py
+++ b/src/moscot/problems/base/_mixins.py
@@ -49,8 +49,8 @@ class AnalysisMixinProtocol(Protocol[K, B]):
     def _apply(
         self,
         data: Optional[Union[str, ArrayLike]] = None,
-        start: Optional[K] = None,
-        end: Optional[K] = None,
+        source: Optional[K] = None,
+        target: Optional[K] = None,
         forward: bool = True,
         return_all: bool = False,
         scale_by_marginals: bool = False,
@@ -304,8 +304,8 @@ def _sample_from_tmap(
         mass = np.ones(target_dim)
         if account_for_unbalancedness and interpolation_parameter is not None:
             col_sums = self._apply(
-                start=source,
-                end=target,
+                source=source,
+                target=target,
                 normalize=True,
                 forward=True,
                 scale_by_marginals=False,
@@ -318,8 +318,8 @@

         row_probability = np.asarray(
             self._apply(
-                start=source,
-                end=target,
+                source=source,
+                target=target,
                 data=mass,
                 normalize=True,
                 forward=False,
@@ -339,8 +339,8 @@

         col_p_given_row = np.asarray(
             self._apply(
-                start=source,
-                end=target,
+                source=source,
+                target=target,
                 data=data,
                 normalize=True,
                 forward=True,
@@ -433,8 +433,8 @@ def _annotation_aggregation_transition(
         func = self.push if forward else self.pull
         for subset in annotations_1:
             result = func(  # TODO(@MUCDK) check how to make compatible with all policies
-                start=source,
-                end=target,
+                source=source,
+                target=target,
                 data=annotation_key,
                 subset=subset,
                 normalize=True,
@@ -472,8 +472,8 @@ def _cell_aggregation_transition(
             batch_size = len(df_2)
         for batch in range(0, len(df_2), batch_size):
             result = func(  # TODO(@MUCDK) check how to make compatible with all policies
-                start=source,
-                end=target,
+                source=source,
+                target=target,
                 data=None,
                 subset=(batch, batch_size),
                 normalize=True,
diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py
index 4513272b7..e28a66d86 100644
--- a/src/moscot/problems/generic/_mixins.py
+++ b/src/moscot/problems/generic/_mixins.py
@@ -107,8 +107,8 @@ def push(

         Parameters
         ----------
-        %(start)s
-        %(end)s
+        %(source)s
+        %(target)s
         %(data)s
         %(subset)s
         %(scale_by_marginals)s
@@ -122,8 +122,8 @@ def push(

         """
         result = self._apply(
-            start=source,
-            end=target,
+            source=source,
+            target=target,
             data=data,
             subset=subset,
             forward=True,
@@ -179,8 +179,8 @@ def pull(

         """
         result = self._apply(
-            start=source,
-            end=target,
+            source=source,
+            target=target,
             data=data,
             subset=subset,
             forward=False,
diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py
index 811466b5e..66f73c6b9 100644
--- a/src/moscot/problems/time/_mixins.py
+++ b/src/moscot/problems/time/_mixins.py
@@ -82,8 +82,8 @@ def _interpolate_gex_with_ot(
         number_cells: int,
         source_data: ArrayLike,
         target_data: ArrayLike,
-        start: K,
-        end: K,
+        source: K,
+        target: K,
         interpolation_parameter: float,
         account_for_unbalancedness: bool = True,
         batch_size: int = 256,
@@ -93,9 +93,9 @@

     def _get_data(
self: "TemporalMixinProtocol[K, B]", - start: K, + source: K, intermediate: Optional[K] = None, - end: Optional[K] = None, + target: Optional[K] = None, posterior_marginals: bool = True, *, only_start: bool = False, @@ -116,8 +116,8 @@ def _interpolate_gex_randomly( def _plot_temporal( self: "TemporalMixinProtocol[K, B]", data: Dict[K, ArrayLike], - start: K, - end: K, + source: K, + target: K, time_points: Optional[Iterable[K]] = None, basis: str = "umap", result_key: Optional[str] = None, @@ -129,7 +129,7 @@ def _plot_temporal( @staticmethod def _get_interp_param( - start: K, intermediate: K, end: K, interpolation_parameter: Optional[float] = None + source: K, intermediate: K, target: K, interpolation_parameter: Optional[float] = None ) -> Numeric_t: ... @@ -293,8 +293,8 @@ def sankey( @d_mixins.dedent def push( self: TemporalMixinProtocol[K, B], - start: K, - end: K, + source: K, + target: K, data: Optional[Union[str, ArrayLike]] = None, subset: Optional[Union[str, List[str], Tuple[int, int]]] = None, scale_by_marginals: bool = True, @@ -323,8 +323,8 @@ def push( """ result = self._apply( - start=start, - end=end, + source=source, + target=target, data=data, subset=subset, forward=True, @@ -338,8 +338,8 @@ def push( if key_added is not None: plot_vars = { - "start": start, - "end": end, + "source": source, + "target": target, "temporal_key": self.temporal_key, "data": data if isinstance(data, str) else None, "subset": subset, @@ -352,8 +352,8 @@ def push( @d_mixins.dedent def pull( self: TemporalMixinProtocol[K, B], - start: K, - end: K, + source: K, + target: K, data: Optional[Union[str, ArrayLike]] = None, subset: Optional[Union[str, List[str], Tuple[int, int]]] = None, scale_by_marginals: bool = True, @@ -382,8 +382,8 @@ def pull( """ result = self._apply( - start=start, - end=end, + source=source, + target=target, data=data, subset=subset, forward=False, @@ -399,8 +399,8 @@ def pull( "temporal_key": self.temporal_key, "data": data if isinstance(data, str) else None, "subset": subset, - "start": start, - "end": end, + "source": source, + "target": target, } self.adata.obs[key_added] = self._flatten(result, key=self.temporal_key) Key.uns.set_plotting_vars(self.adata, AdataKeys.UNS, PlottingKeys.PULL, key_added, plot_vars) @@ -410,9 +410,9 @@ def pull( # TODO(michalk8): refactor me def _get_data( self: TemporalMixinProtocol[K, B], - start: K, + source: K, intermediate: Optional[K] = None, - end: Optional[K] = None, + target: Optional[K] = None, posterior_marginals: bool = True, *, only_start: bool = False, @@ -425,7 +425,7 @@ def _get_data( f"Expected `tag={Tag.POINT_CLOUD}`, " # type: ignore[union-attr] f"found `tag={self.problems[src, tgt].xy.tag}`." 
                 )
-            if src == start:
+            if src == source:
                 source_data = self.problems[src, tgt].xy.data_src  # type: ignore[union-attr]
                 if only_start:
                     return source_data, self.problems[src, tgt].adata_src
@@ -434,7 +434,7 @@ def _get_data(
                 growth_rates_source = getattr(self.problems[src, tgt], attr)
                 break
         else:
-            raise ValueError(f"No data found for `{start}` time point.")
+            raise ValueError(f"No data found for `{source}` time point.")
         for (src, tgt) in self.problems.keys():
             if src == intermediate:
                 intermediate_data = self.problems[src, tgt].xy.data_src  # type: ignore[union-attr]
@@ -443,11 +443,11 @@
         else:
             raise ValueError(f"No data found for `{intermediate}` time point.")
         for (src, tgt) in self.problems.keys():
-            if tgt == end:
+            if tgt == target:
                 target_data = self.problems[src, tgt].xy.data_tgt  # type: ignore[union-attr]
                 break
         else:
-            raise ValueError(f"No data found for `{end}` time point.")
+            raise ValueError(f"No data found for `{target}` time point.")

         return (  # type: ignore[return-value]
             source_data,
@@ -459,9 +459,9 @@

     def compute_interpolated_distance(
         self: TemporalMixinProtocol[K, B],
-        start: K,
+        source: K,
         intermediate: K,
-        end: K,
+        target: K,
         interpolation_parameter: Optional[float] = None,
         n_interpolated_cells: Optional[int] = None,
         account_for_unbalancedness: bool = False,
@@ -507,22 +507,22 @@
        Wasserstein distance between OT-based interpolated distribution and the true cell distribution.
        """
         source_data, _, intermediate_data, _, target_data = self._get_data(  # type: ignore[misc]
-            start,
+            source,
             intermediate,
-            end,
+            target,
             posterior_marginals=posterior_marginals,
             only_start=False,
         )
         interpolation_parameter = self._get_interp_param(
-            start, intermediate, end, interpolation_parameter=interpolation_parameter
+            source, intermediate, target, interpolation_parameter=interpolation_parameter
         )
         n_interpolated_cells = n_interpolated_cells if n_interpolated_cells is not None else len(intermediate_data)
         interpolation = self._interpolate_gex_with_ot(
             number_cells=n_interpolated_cells,
             source_data=source_data,
             target_data=target_data,
-            start=start,
-            end=end,
+            source=source,
+            target=target,
             interpolation_parameter=interpolation_parameter,
             account_for_unbalancedness=account_for_unbalancedness,
             batch_size=batch_size,
@@ -534,9 +534,9 @@

     def compute_random_distance(
         self: TemporalMixinProtocol[K, B],
-        start: K,
+        source: K,
         intermediate: K,
-        end: K,
+        target: K,
         interpolation_parameter: Optional[float] = None,
         n_interpolated_cells: Optional[int] = None,
         account_for_unbalancedness: bool = False,
@@ -571,11 +571,11 @@
        The Wasserstein distance between a randomly interpolated cell distribution and the true cell distribution.
""" source_data, growth_rates_source, intermediate_data, _, target_data = self._get_data( # type: ignore[misc] - start, intermediate, end, posterior_marginals=posterior_marginals, only_start=False + source, intermediate, target, posterior_marginals=posterior_marginals, only_start=False ) interpolation_parameter = self._get_interp_param( - start, intermediate, end, interpolation_parameter=interpolation_parameter + source, intermediate, target, interpolation_parameter=interpolation_parameter ) n_interpolated_cells = n_interpolated_cells if n_interpolated_cells is not None else len(intermediate_data) @@ -592,9 +592,9 @@ def compute_random_distance( def compute_time_point_distances( self: TemporalMixinProtocol[K, B], - start: K, + source: K, intermediate: K, - end: K, + target: K, posterior_marginals: bool = True, backend: Literal["ott"] = "ott", **kwargs: Any, @@ -617,9 +617,9 @@ def compute_time_point_distances( %(kwargs_divergence)s """ source_data, _, intermediate_data, _, target_data = self._get_data( # type: ignore[misc] - start, + source, intermediate, - end, + target, posterior_marginals=posterior_marginals, only_start=False, ) @@ -693,16 +693,16 @@ def _interpolate_gex_with_ot( number_cells: int, source_data: ArrayLike, target_data: ArrayLike, - start: K, - end: K, + source: K, + target: K, interpolation_parameter: float, account_for_unbalancedness: bool = True, batch_size: int = 256, seed: Optional[int] = None, ) -> ArrayLike: rows_sampled, cols_sampled = self._sample_from_tmap( - source=start, - target=end, + source=source, + target=target, n_samples=number_cells, source_dim=len(source_data), target_dim=len(target_data), @@ -740,12 +740,12 @@ def _interpolate_gex_randomly( @staticmethod def _get_interp_param( - start: K, intermediate: K, end: K, interpolation_parameter: Optional[float] = None + source: K, intermediate: K, target: K, interpolation_parameter: Optional[float] = None ) -> Numeric_t: if TYPE_CHECKING: - assert isinstance(start, float) + assert isinstance(source, float) assert isinstance(intermediate, float) - assert isinstance(end, float) + assert isinstance(target, float) if interpolation_parameter is not None: if 0 < interpolation_parameter < 1: return interpolation_parameter @@ -753,10 +753,10 @@ def _get_interp_param( f"Expected interpolation parameter to be in interval `(0, 1)`, found `{interpolation_parameter}`." ) - if start < intermediate < end: - return (intermediate - start) / (end - start) + if source < intermediate < target: + return (intermediate - source) / (target - source) raise ValueError( - f"Expected intermediate time point to be in interval `({start}, {end})`, found `{intermediate}`." + f"Expected intermediate time point to be in interval `({source}, {target})`, found `{intermediate}`." 
         )

     @property
diff --git a/tests/analysis_mixins/test_base_analysis.py b/tests/analysis_mixins/test_base_analysis.py
index b5a10bcb2..aa01cd3a7 100644
--- a/tests/analysis_mixins/test_base_analysis.py
+++ b/tests/analysis_mixins/test_base_analysis.py
@@ -178,7 +178,7 @@ def test_compute_feature_correlation(self, adata_time: AnnData, method: Literal[
         problem = problem.prepare("time", xy_callback="local-pca")
         problem[0, 1]._solution = MockSolverOutput(tmap)

-        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(start=0, end=1).squeeze()))
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))

         res = problem.compute_feature_correlation(obs_key=key_added, method=method)

@@ -207,7 +207,7 @@ def test_compute_feature_correlation_subset(
         problem = problem.prepare("time", xy_callback="local-pca")
         problem[0, 1]._solution = MockSolverOutput(tmap)

-        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(start=0, end=1).squeeze()))
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))

         if isinstance(features, int):
             features = list(adata_time.var_names)[:features]
@@ -246,7 +246,7 @@ def test_compute_feature_correlation_transcription_factors(
         problem = problem.prepare("time", xy_callback="local-pca")
         problem[0, 1]._solution = MockSolverOutput(tmap)

-        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(start=0, end=1).squeeze()))
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))

         if features[0] == "error":
             with np.testing.assert_raises(NotImplementedError):
@@ -273,7 +273,7 @@ def test_seed_reproducible(self, adata_time: AnnData):
         problem = problem.prepare("time", xy_callback="local-pca")
         problem[0, 1]._solution = MockSolverOutput(tmap)

-        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(start=0, end=1).squeeze()))
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))

         res_a = problem.compute_feature_correlation(obs_key=key_added, n_perms=10, n_jobs=1, seed=0, method="perm_test")
         res_b = problem.compute_feature_correlation(obs_key=key_added, n_perms=10, n_jobs=1, seed=0, method="perm_test")
@@ -298,7 +298,7 @@ def test_seed_reproducible_parallelized(self, adata_time: AnnData):
         problem = problem.prepare("time", xy_callback="local-pca")
         problem[0, 1]._solution = MockSolverOutput(tmap)

-        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(start=0, end=1).squeeze()))
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))

         res_a = problem.compute_feature_correlation(
             obs_key=key_added, n_perms=10, n_jobs=2, backend="threading", seed=0, method="perm_test"
@@ -324,7 +324,7 @@ def test_confidence_level(self, adata_time: AnnData):
         problem = problem.prepare("time", xy_callback="local-pca")
         problem[0, 1]._solution = MockSolverOutput(tmap)

-        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(start=0, end=1).squeeze()))
+        adata_time.obs[key_added] = np.hstack((np.zeros(n0), problem.pull(source=0, target=1).squeeze()))

         res_narrow = problem.compute_feature_correlation(obs_key=key_added, confidence_level=0.95)
         res_wide = problem.compute_feature_correlation(obs_key=key_added, confidence_level=0.99)
diff --git a/tests/plotting/conftest.py b/tests/plotting/conftest.py
index 057ec70dc..776e2cabe 100644
--- a/tests/plotting/conftest.py
+++ b/tests/plotting/conftest.py
@@ -41,7 +41,7 @@ def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData:
 @pytest.fixture()
 def adata_pl_push(adata_time: AnnData) -> AnnData:
     rng = np.random.RandomState(0)
-    plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "start": 0, "end": 1}
+    plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
     adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
     adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
     Key.uns.set_plotting_vars(adata_time, AdataKeys.UNS, PlottingKeys.PUSH, PlottingDefaults.PUSH, plot_vars)
@@ -60,7 +60,7 @@ def adata_pl_push(adata_time: AnnData) -> AnnData:
 @pytest.fixture()
 def adata_pl_pull(adata_time: AnnData) -> AnnData:
     rng = np.random.RandomState(0)
-    plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "start": 0, "end": 1}
+    plot_vars = {"temporal_key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1}
     adata_time.uns["celltype_colors"] = ["#cc1b1b", "#2ccc1b", "#cc1bcc"]
     adata_time.obs["celltype"] = adata_time.obs["celltype"].astype("category")
     Key.uns.set_plotting_vars(adata_time, AdataKeys.UNS, PlottingKeys.PULL, PlottingDefaults.PULL, plot_vars)
diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py
index a0fc9ee3d..8418a9b52 100644
--- a/tests/problems/time/test_mixins.py
+++ b/tests/problems/time/test_mixins.py
@@ -129,9 +129,9 @@ def test_compute_time_point_distances_pipeline(self, adata_time: AnnData):
         problem = TemporalProblem(adata_time)
         problem.prepare("time")
         distance_source_intermediate, distance_intermediate_target = problem.compute_time_point_distances(
-            start=0,
+            source=0,
             intermediate=1,
-            end=2,
+            target=2,
             posterior_marginals=False,
         )
         assert distance_source_intermediate > 0
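Usage note: with this rename, the user-facing `push`, `pull`, and validation methods take `source`/`target` instead of `start`/`end`. A minimal sketch of the renamed calls against `TemporalProblem`, mirroring the tests above; the toy AnnData construction and the solver settings are illustrative assumptions, not part of this patch:

    import anndata as ad
    import numpy as np
    import pandas as pd

    from moscot.problems.time import TemporalProblem

    # Toy dataset with two time points stored in `.obs["time"]` (placeholder data).
    rng = np.random.RandomState(0)
    adata = ad.AnnData(X=rng.normal(size=(200, 50)).astype(np.float32))
    adata.obs["time"] = pd.Categorical([0] * 100 + [1] * 100)

    problem = TemporalProblem(adata)
    problem = problem.prepare("time", xy_callback="local-pca")  # as in the tests above
    problem = problem.solve(epsilon=1e-2)  # illustrative solver settings

    # Before this patch: problem.push(start=0, end=1) / problem.pull(start=0, end=1)
    pushed = problem.push(source=0, target=1)  # mass of time-0 cells transported onto time 1
    pulled = problem.pull(source=0, target=1)  # ancestors of time-1 cells at time 0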