Skip to content

Commit

Permalink
incorporate requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
MUCDK committed Dec 9, 2022
1 parent 778a805 commit 8178df9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 27 deletions.
10 changes: 5 additions & 5 deletions src/moscot/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ def handle_joint_attr(
"y_key": joint_attr,
}
return xy, kwargs
if isinstance(joint_attr, Mapping):
if isinstance(joint_attr, Mapping): # input mapping does not distinguish between x and y as it's a shared space
joint_attr = dict(joint_attr)
if "attr" in joint_attr and joint_attr["attr"] == "X":
if "attr" in joint_attr and joint_attr["attr"] == "X": # we have a point cloud
return {"x_attr": "X", "y_attr": "X"}, kwargs
if "attr" in joint_attr and joint_attr["attr"] == "obsm":
if "attr" in joint_attr and joint_attr["attr"] == "obsm": # we have a point cloud
if "key" not in joint_attr:
raise KeyError("`key` must be provided when `attr` is `obsm`.")
xy = {
Expand All @@ -96,8 +96,8 @@ def handle_joint_attr(
"y_key": joint_attr["key"],
}
return xy, kwargs
if joint_attr.get("tag", None) == "cost_matrix":
if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp":
if joint_attr.get("tag", None) == "cost_matrix": # if this is True we have custom cost matrix or moscot cost
if len(joint_attr) == 2 or kwargs.get("attr", None) == "obsp": # in this case we have a custom cost matrix
joint_attr.setdefault("cost", "custom")
joint_attr.setdefault("attr", "obsp")
kwargs["xy_callback"] = "cost-matrix"
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/problems/base/_base_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def concat(x: ArrayLike, y: ArrayLike) -> ArrayLike:
if scaler is not None:
data = scaler.fit_transform(data)
return {"xy": TaggedArray(data[:n], data[n:], tag=Tag.POINT_CLOUD)}
if term in ("x", "y"):
if term in ("x", "y"): # if we don't have a shared space, then adata_y is always None
logger.info(f"Computing pca with `n_comps={n_comps}` for `{term}` using `{msg}`")
x = sc.pp.pca(x, n_comps=n_comps, **kwargs)
if scaler is not None:
Expand Down
44 changes: 23 additions & 21 deletions src/moscot/problems/base/_compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _callback_handler(
key_2: K,
problem: B,
*,
callback: Union[Literal["local-pca"], Callback_t],
callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
**kwargs: Any,
) -> Mapping[Literal["xy", "x", "y"], TaggedArray]:
def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None:
Expand All @@ -115,6 +115,8 @@ def verify_data(data: Mapping[Literal["xy", "x", "y"], TaggedArray]) -> None:
if not isinstance(val, TaggedArray):
raise TypeError(f"Expected value for `{key}` to be a `TaggedArray`, found `{type(val)}`.")

if callback is None:
return {}
if callback == "local-pca":
callback = problem._local_pca_callback

Expand Down Expand Up @@ -150,24 +152,19 @@ def _create_problems(
tgt_name = tgt

problem = self._create_problem(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
if xy_callback is not None:
xy_data = self._callback_handler(
term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs
)
else:
xy_data = {}
if x_callback is not None:
x_data = self._callback_handler(
term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs
)
else:
x_data = {}
if y_callback is not None:
y_data = self._callback_handler(
term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs
)
else:
y_data = {}

xy_data = self._callback_handler(
term="xy", key_1=src, key_2=tgt, problem=problem, callback=xy_callback, **xy_callback_kwargs
)

x_data = self._callback_handler(
term="x", key_1=src, key_2=tgt, problem=problem, callback=x_callback, **x_callback_kwargs
)

y_data = self._callback_handler(
term="y", key_1=src, key_2=tgt, problem=problem, callback=y_callback, **y_callback_kwargs
)

kws = {**kwargs, **xy_data, **x_data, **y_data} # type: ignore[arg-type]

if isinstance(problem, BirthDeathProblem):
Expand Down Expand Up @@ -626,7 +623,7 @@ def _callback_handler(
key_2: K,
problem: B,
*,
callback: Union[Literal["local-pca", "cost-matrix"], Callback_t],
callback: Optional[Union[Literal["local-pca", "cost-matrix"], Callback_t]] = None,
**kwargs: Any,
) -> Mapping[Literal["xy", "x", "y"], TaggedArray]:
# TODO(michalk8): better name?
Expand Down Expand Up @@ -657,10 +654,15 @@ def _cost_matrix_callback(

linear_cost_matrix = data[mask, :][:, mask_2]
if issparse(linear_cost_matrix):
logger.warning("Linear cost matrix being densified.")
linear_cost_matrix = linear_cost_matrix.A
return {"xy": TaggedArray(linear_cost_matrix, tag=Tag.COST_MATRIX)}

if term in ("x", "y"):
return {term: TaggedArray(data[mask, :][:, mask], tag=Tag.COST_MATRIX)}
quad_cost_matrix = data[mask, :][:, mask_2]
if issparse(quad_cost_matrix):
logger.warning("Quadratic cost matrix being densified.")
quad_cost_matrix = quad_cost_matrix.A
return {term: TaggedArray(quad_cost_matrix, tag=Tag.COST_MATRIX)}

raise ValueError(f"Expected `term` to be one of `x`, `y`, or `xy`, found `{term!r}`.")

0 comments on commit 8178df9

Please sign in to comment.