Merge pull request #612 from Sichao25/pred
Debug dynast.py
Xiaojieqiu authored Jan 29, 2024
2 parents c99d0c3 + 152a9db commit a053564
Showing 1 changed file with 49 additions and 43 deletions.
92 changes: 49 additions & 43 deletions dynamo/preprocessing/dynast.py
@@ -19,18 +19,20 @@ def lambda_correction(
"""Use lambda (cell-wise detection rate) to estimate the labelled RNA.
Args:
adata: an adata object generated from dynast.
lambda_key: the key to the cell-wise detection rate. Defaults to "lambda".
inplace: whether to inplace update the layers. If False, new layers that append '_corrected" to the existing
adata: An adata object generated from dynast.
lambda_key: The key to the cell-wise detection rate. Defaults to "lambda".
inplace: Whether to inplace update the layers. If False, new layers that append '_corrected" to the existing
will be used to store the updated data. Defaults to True.
copy: whether to copy the adata object or update adata object inplace. Defaults to False.
copy: Whether to copy the adata object or update adata object inplace. Defaults to False.
Raises:
ValueError: the `lambda_key` cannot be found in `adata.obs`
ValueError: The `lambda_key` cannot be found in `adata.obs`.
ValueError: The adata object has to include labeling layers.
ValueError: `data_type` is set to 'splicing_labeling' but the existing layers in the adata object don't meet the
requirements.
ValueError: `data_type` is set to 'labeling' but the existing layers in the adata object don't meet the
requirements.
Returns:
A new AnnData object that are updated with lambda corrected layers if `copy` is true. Otherwise, return None.
"""
@@ -54,18 +56,22 @@ def lambda_correction(
     logger.info("identify the data type..", indent_level=1)
     all_layers = adata.layers.keys()

-    has_ul = np.any([i.contains("ul_") for i in all_layers])
-    has_un = np.any([i.contains("un_") for i in all_layers])
-    has_sl = np.any([i.contains("sl_") for i in all_layers])
-    has_sn = np.any([i.contains("sn_") for i in all_layers])
+    has_ul = np.any(["ul_" in i for i in all_layers])
+    has_un = np.any(["un_" in i for i in all_layers])
+    has_sl = np.any(["sl_" in i for i in all_layers])
+    has_sn = np.any(["sn_" in i for i in all_layers])

-    has_l = np.any([i.contains("_l_") for i in all_layers])
-    has_n = np.any([i.contains("_n_") for i in all_layers])
+    has_l = np.any(["_l_" in i for i in all_layers])
+    has_n = np.any(["_n_" in i for i in all_layers])

-    if sum(has_ul + has_un + has_sl + has_sn) == 4:
+    if np.count_nonzero([has_ul, has_un, has_sl, has_sn]) == 4:
         datatype = "splicing_labeling"
-    elif sum(has_l + has_n):
+    elif np.count_nonzero([has_l, has_n]):
         datatype = "labeling"
+    else:
+        raise ValueError(
+            "the adata object has to include labeling layers."
+        )

     logger.info(f"the data type identified is {datatype}", indent_level=2)
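Two distinct bugs are fixed in the hunk above. Plain Python strings have no .contains() method (that API belongs to pandas' .str accessor), so i.contains("ul_") raised AttributeError; the substring test is spelled "ul_" in i. And sum(has_ul + has_un + has_sl + has_sn) performed bool arithmetic rather than counting flags; np.count_nonzero counts True entries directly. A minimal sketch with hypothetical layer names:

import numpy as np

# Hypothetical layer names resembling dynast output.
all_layers = ["X_ul_raw", "X_un_raw", "X_sl_raw", "X_sn_raw", "spliced", "unspliced"]

# Substring membership via the `in` operator; plain str has no .contains().
has_ul = np.any(["ul_" in i for i in all_layers])
has_un = np.any(["un_" in i for i in all_layers])
has_sl = np.any(["sl_" in i for i in all_layers])
has_sn = np.any(["sn_" in i for i in all_layers])

# np.count_nonzero counts True flags without relying on bool arithmetic.
print(np.count_nonzero([has_ul, has_un, has_sl, has_sn]) == 4)  # True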

@@ -74,44 +80,44 @@
         layers, match_tot_layer = [], []
         for layer in all_layers:
             if "ul_" in layer:
-                layers += layer
-                match_tot_layer += "unspliced"
+                layers.append(layer)
+                match_tot_layer.append("unspliced")
             elif "un_" in layer:
-                layers += layer
-                match_tot_layer += "unspliced"
+                layers.append(layer)
+                match_tot_layer.append("unspliced")
             elif "sl_" in layer:
-                layers += layer
-                match_tot_layer += "spliced"
+                layers.append(layer)
+                match_tot_layer.append("spliced")
             elif "sn_" in layer:
-                layers += layer
-                match_tot_layer += "spliced"
+                layers.append(layer)
+                match_tot_layer.append("spliced")
             elif "spliced" in layer:
-                layers += layer
+                layers.append(layer)
             elif "unspliced" in layer:
-                layers += layer
+                layers.append(layer)

-            if len(layers) != 6:
-                raise ValueError(
-                    "the adata object has to include ul, un, sl, sn, unspliced, spliced, "
-                    "six relevant layers for splicing and labeling quantified datasets."
-                )
+        if len(layers) != 6:
+            raise ValueError(
+                "the adata object has to include ul, un, sl, sn, unspliced, spliced, "
+                "six relevant layers for splicing and labeling quantified datasets."
+            )
     elif datatype == "labeling":
         layers, match_tot_layer = [], []
         for layer in all_layers:
             if "_l_" in layer:
-                layers += layer
-                match_tot_layer += ["total"]
+                layers.append(layer)
+                match_tot_layer.append("total")
             elif "_n_" in layer:
-                layers += layer
-                match_tot_layer += ["total"]
+                layers.append(layer)
+                match_tot_layer.append("total")
             elif "total" in layer:
-                layers += layer
+                layers.append(layer)

-            if len(layers) != 3:
-                raise ValueError(
-                    "the adata object has to include labeled, unlabeled, three relevant layers for labeling quantified "
-                    "datasets."
-                )
+        if len(layers) != 3:
+            raise ValueError(
+                "the adata object has to include labeled, unlabeled, three relevant layers for labeling quantified "
+                "datasets."
+            )

     logger.info("detection rate correction starts", indent_level=1)
     for i, layer in enumerate(main_tqdm(layers, desc="iterating all relevant layers")):
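The hunk above fixes the list-building logic. layers += layer extends a list with an iterable, so a layer name was split into its individual characters (and match_tot_layer += "unspliced" did the same); .append() adds the whole string as one element. The removed and re-added length checks differ only in indentation, which suggests they moved out of the for loop so they fire only after all layers have been collected. A quick demonstration of the += pitfall:

layers = []
layer = "X_new"  # hypothetical layer name

layers += layer          # += iterates the string
print(layers)            # ['X', '_', 'n', 'e', 'w']

layers = []
layers.append(layer)     # append keeps the name intact
print(layers)            # ['X_new']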
@@ -133,9 +139,9 @@

         else:
             if inplace:
-                adata.layers[layer] = cur_total - adata.layers[layer[i - 1]]
+                adata.layers[layer] = cur_total - adata.layers[layers[i - 1]]
             else:
-                adata.layers[layer + "_corrected"] = cur_total - adata.layers[layer[i - 1]]
+                adata.layers[layer + "_corrected"] = cur_total - adata.layers[layers[i - 1]]

     logger.finish_progress(progress_name="lambda_correction")
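The fix above swaps layer[i - 1] for layers[i - 1]. Since layer is a single layer-name string, indexing it yielded one character, which was then used as a nonexistent layer key; layers[i - 1] retrieves the previous entry of the collected layer list, as intended. With hypothetical names:

layers = ["X_new_layer", "X_total_layer"]  # hypothetical layer names
i = 1
layer = layers[i]        # "X_total_layer"

print(layer[i - 1])      # 'X'            -- a character, not a layer key
print(layers[i - 1])     # 'X_new_layer'  -- the intended previous layer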

@@ -148,15 +154,15 @@ def sparse_mimmax(A: csr_matrix, B: csr_matrix, type="min") -> csr_matrix:
"""Return the element-wise minimum/maximum of sparse matrices `A` and `B`.
Args:
A: The first sparse matrix
B: The second sparse matrix
A: The first sparse matrix.
B: The second sparse matrix.
type: The type of calculation, either "min" or "max". Defaults to "min".
Returns:
A sparse matrix that contain the element-wise maximal or minimal of two sparse matrices.
"""

AgtB = (A < B).astype(int) if type == "min" else (A > B).astype(int)
M = AgtB.multiply(A - B) + B
M = np.multiply(AgtB, A - B) + B

return M
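sparse_mimmax relies on a masking identity: where the mask AgtB is 1, (A - B) + B collapses to A; everywhere else only the B entry survives, which yields the element-wise minimum (or maximum, with the reversed comparison). A dense-array sketch of the same algebra, since the identity is easiest to verify there; the function itself applies it to csr_matrix inputs:

import numpy as np

A = np.array([[1, 5], [0, 3]])
B = np.array([[2, 4], [1, 0]])

# Mask of positions where A supplies the minimum.
AgtB = (A < B).astype(int)

# Where the mask is 1, (A - B) + B == A; elsewhere the B entry survives.
M = np.multiply(AgtB, A - B) + B
print(M)  # [[1 4]
          #  [0 0]]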
