-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
315b532
commit 50141ae
Showing
2 changed files
with
167 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# This script finds the most similar dataset from the atlas for a given user-uploaded dataset | ||
# It calculates similarity scores and returns the best matching dataset along with its configurations | ||
|
||
import argparse | ||
import json | ||
|
||
import pandas as pd | ||
import scanpy as sc | ||
|
||
from dance import logger | ||
from dance.atlas.sc_similarity.anndata_similarity import AnnDataSimilarity, get_anndata | ||
from dance.settings import DANCEDIR, SIMILARITYDIR | ||
|
||
|
||
def calculate_similarity(source_data, tissue, atlas_datasets, reduce_error, in_query):
    """Score every candidate atlas dataset against the user's data.

    The score for one atlas dataset is a weighted sum of two similarity
    measures returned by ``AnnDataSimilarity.get_similarity_matrix_A2B``: the
    tissue-specific feature similarity (weight ``weight1``) and
    ``"metadata_sim"`` (weight ``1 - weight1``). The feature name and
    ``weight1`` are read from a JSON weights file keyed by tissue.

    Args:
        source_data: User uploaded AnnData object.
        tissue: Target tissue type; used as the key into the weights file and,
            capitalized, as the tissue passed to ``get_anndata``.
        atlas_datasets: Iterable of candidate dataset identifiers from the atlas.
        reduce_error: Flag for error reduction mode. When true, the
            ``reduce_error_`` variant of the weights file is loaded.
        in_query: Flag for query mode. When true, the ``in_query_`` variant of
            the weights file is loaded.

    Returns:
        Dictionary mapping each atlas dataset identifier to its weighted
        similarity score.
    """
    with open(
            SIMILARITYDIR /
            f"data/similarity_weights_results/{'reduce_error_' if reduce_error else ''}{'in_query_' if in_query else ''}sim_dict.json",
            encoding='utf-8') as weights_file:
        tissue_weights = json.load(weights_file)[tissue]

    feature_name = tissue_weights["feature_name"]
    feature_weight = tissue_weights["weight1"]
    metadata_weight = 1 - feature_weight

    # These arguments are the same for every candidate, so compute them once.
    atlas_tissue = tissue.capitalize()
    atlas_data_dir = str(DANCEDIR / "examples/tuning/temp_data")
    methods = [feature_name, "metadata_sim"]

    scores = {}
    for target_file in atlas_datasets:
        logger.info(f"calculating similarity for {target_file}")
        atlas_data = get_anndata(
            tissue=atlas_tissue,
            species="human",
            filetype="h5ad",
            train_dataset=[f"{target_file}"],
            data_dir=atlas_data_dir,
        )
        comparer = AnnDataSimilarity(
            adata1=source_data,
            adata2=atlas_data,
            sample_size=10,
            init_random_state=42,
            n_runs=1,
            tissue=tissue,
        )
        similarities = comparer.get_similarity_matrix_A2B(methods=methods)
        feature_part = similarities[feature_name] * feature_weight
        metadata_part = similarities["metadata_sim"] * metadata_weight
        scores[target_file] = feature_part + metadata_part
    return scores
|
||
|
||
def main(args):
    """Process user data and find the most similar atlas dataset.

    Candidate datasets are the rows of the tissue's sheet in
    "Cell Type Annotation Atlas.xlsx" whose ``queryed`` column equals False.
    Each candidate is scored with ``calculate_similarity`` and the highest
    scoring one is returned together with its stored configurations.

    Args:
        args: Object exposing the attributes:
            - tissue: Target tissue type (lower-cased before use)
            - data_dir: Directory containing the source data
            - source_file: Name of the source file, without the ``.h5ad`` suffix

    Returns:
        tuple containing:
            - ans_file: ID of the most similar dataset
            - ans_conf: Dictionary mapping each cell type annotation method to
              the value of its ``<method>_step2_best_yaml`` column for that dataset
            - ans_value: Similarity score of the best matching dataset

    Raises:
        ValueError: If the tissue sheet contains no candidate datasets, so
            there is nothing to compare against.
    """
    reduce_error = False
    in_query = True
    tissue = args.tissue.lower()
    conf_data = pd.read_excel(SIMILARITYDIR / "data/Cell Type Annotation Atlas.xlsx", sheet_name=tissue)
    # Element-wise comparison on a pandas Series: `== False` is intentional here.
    atlas_datasets = list(conf_data[conf_data["queryed"] == False]["dataset_id"])  # noqa: E712

    # Fail early with a clear message instead of letting max() raise
    # "max() arg is an empty sequence" after the expensive data load.
    if not atlas_datasets:
        raise ValueError(f"No candidate atlas datasets found for tissue '{tissue}'")

    source_data = sc.read_h5ad(f"{args.data_dir}/{args.source_file}.h5ad")

    ans = calculate_similarity(source_data, tissue, atlas_datasets, reduce_error, in_query)
    ans_file = max(ans, key=ans.get)
    ans_value = ans[ans_file]
    ans_conf = {
        method: conf_data.loc[conf_data["dataset_id"] == ans_file, f"{method}_step2_best_yaml"].iloc[0]
        for method in ["cta_celltypist", "cta_scdeepsort", "cta_singlecellnet", "cta_actinn"]
    }
    return ans_file, ans_conf, ans_value
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: every option has a default pointing at a
    # sample human Brain dataset, so the script can be run without arguments.
    cli_options = (
        ("--tissue", "Brain"),
        ("--data_dir", str(DANCEDIR / "examples/tuning/temp_data/train/human")),
        ("--source_file", "human_Brain364348b4-bc34-4fe1-a851-60d99e36cafa_data"),
    )
    parser = argparse.ArgumentParser()
    for option_name, option_default in cli_options:
        parser.add_argument(option_name, default=option_default)

    # Print the best matching dataset, its configurations and its score.
    best_match = main(parser.parse_args())
    print(*best_match)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
Test suite for the Atlas similarity calculation functionality. | ||
This test verifies that the main function correctly returns: | ||
1. The most similar dataset from the atlas | ||
2. Its corresponding configuration settings | ||
3. The similarity score | ||
The test ensures: | ||
- Return value types are correct | ||
- Similarity score is within valid range (0-1) | ||
- Configuration dictionary contains all required cell type annotation methods | ||
""" | ||
|
||
import json | ||
import sys | ||
|
||
import pandas as pd | ||
|
||
from dance.settings import ATLASDIR, DANCEDIR, SIMILARITYDIR | ||
|
||
sys.path.append(str(ATLASDIR)) | ||
from demos.main import main | ||
|
||
from dance import logger | ||
|
||
|
||
def test_main():
    """Check ``main`` against a sample Brain dataset and the stored Excel results.

    Verifies the types of the three returned values, that the similarity score
    lies in [0, 1], that the configuration dictionary has exactly the four
    expected annotation methods with string values, and that the selected
    dataset and its score agree with the weighted similarity recomputed from
    the tissue's similarity spreadsheet.
    """

    # Test parameters: a sample Brain tissue dataset.
    class Args:
        tissue = "Brain"
        data_dir = str(DANCEDIR / "examples/tuning/temp_data/train/human")
        source_file = "human_Brain364348b4-bc34-4fe1-a851-60d99e36cafa_data"

    args = Args()
    logger.info(f"testing main with args: {args}")
    source_id = "3643"  # sheet name of the source dataset in the similarity workbook

    # Run the function under test.
    ans_file, ans_conf, ans_value = main(args)

    # Return value types and ranges.
    assert isinstance(ans_file, str), "ans_file should be a string type"
    assert isinstance(ans_value, float), "ans_value should be a float type"
    assert 0 <= ans_value <= 1, "Similarity value should be between 0 and 1"

    # Configuration dictionary structure and content.
    expected_methods = {"cta_celltypist", "cta_scdeepsort", "cta_singlecellnet", "cta_actinn"}
    assert isinstance(ans_conf, dict), "ans_conf should be a dictionary type"
    assert set(ans_conf) == expected_methods, "ans_conf should contain all expected methods"
    assert all(isinstance(v, str) for v in ans_conf.values()), "All configuration values should be string type"

    # Reference similarities stored in the Excel spreadsheet.
    tissue_key = args.tissue.lower()
    data = pd.read_excel(SIMILARITYDIR / f"data/new_sim/{args.tissue.lower()}_similarity.xlsx", sheet_name=source_id,
                         index_col=0)

    # Load the same weights file that main() uses (in_query mode, no error reduction).
    reduce_error = False
    in_query = True
    with open(
            SIMILARITYDIR /
            f"data/similarity_weights_results/{'reduce_error_' if reduce_error else ''}{'in_query_' if in_query else ''}sim_dict.json",
            encoding='utf-8') as weights_file:
        sim_dict = json.load(weights_file)
    feature_name = sim_dict[tissue_key]["feature_name"]
    feature_weight = sim_dict[tissue_key]["weight1"]
    metadata_weight = 1 - feature_weight

    # Recompute the weighted similarity from the spreadsheet rows.
    feature_row = data.loc[feature_name]
    metadata_row = data.loc["metadata_sim"]
    data.loc["similarity"] = feature_row * feature_weight + metadata_row * metadata_weight
    expected_file = data.loc["similarity"].idxmax()
    expected_value = data.loc["similarity", expected_file]

    # The computed result must agree with the spreadsheet.
    assert abs(ans_value - expected_value) < 1e-4, "Calculated similarity value does not match Excel value"
    assert ans_file == expected_file, "Selected most similar dataset does not match Excel result"