Skip to content

Commit

Permalink
update user interface
Browse files Browse the repository at this point in the history
  • Loading branch information
xingzhongyu committed Feb 6, 2025
1 parent 315b532 commit 50141ae
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 0 deletions.
93 changes: 93 additions & 0 deletions examples/atlas/demos/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# This script finds the most similar dataset from the atlas for a given user-uploaded dataset
# It calculates similarity scores and returns the best matching dataset along with its configurations

import argparse
import json

import pandas as pd
import scanpy as sc

from dance import logger
from dance.atlas.sc_similarity.anndata_similarity import AnnDataSimilarity, get_anndata
from dance.settings import DANCEDIR, SIMILARITYDIR


def calculate_similarity(source_data, tissue, atlas_datasets, reduce_error, in_query):
"""Calculate similarity scores between source data and atlas datasets.
Args:
source_data: User uploaded AnnData object
tissue: Target tissue type
atlas_datasets: List of candidate datasets from atlas
reduce_error: Flag for error reduction mode - when True, applies a significant penalty
to configurations in the atlas that produced errors
in_query: Flag for query mode - when True, ranks similarity based on query performance,
when False, ranks based on inter-atlas comparison
Returns:
Dictionary containing similarity scores for each atlas dataset
"""
with open(
SIMILARITYDIR /
f"data/similarity_weights_results/{'reduce_error_' if reduce_error else ''}{'in_query_' if in_query else ''}sim_dict.json",
encoding='utf-8') as f:
sim_dict = json.load(f)
feature_name = sim_dict[tissue]["feature_name"]
w1 = sim_dict[tissue]["weight1"]
w2 = 1 - w1
ans = {}
for target_file in atlas_datasets:
logger.info(f"calculating similarity for {target_file}")
atlas_data = get_anndata(tissue=tissue.capitalize(), species="human", filetype="h5ad",
train_dataset=[f"{target_file}"], data_dir=str(DANCEDIR / "examples/tuning/temp_data"))
similarity_calculator = AnnDataSimilarity(adata1=source_data, adata2=atlas_data, sample_size=10,
init_random_state=42, n_runs=1, tissue=tissue)
sim_target = similarity_calculator.get_similarity_matrix_A2B(methods=[feature_name, "metadata_sim"])
ans[target_file] = sim_target[feature_name] * w1 + sim_target["metadata_sim"] * w2
return ans


def main(args):
"""Main function to process user data and find the most similar atlas dataset.
Args:
args: Arguments containing:
- tissue: Target tissue type
- data_dir: Directory containing the source data
- source_file: Name of the source file
Returns:
tuple containing:
- ans_file: ID of the most similar dataset
- ans_conf: Preprocess configuration dictionary for different cell type annotation methods
- ans_value: Similarity score of the best matching dataset
"""
reduce_error = False
in_query = True
tissue = args.tissue
tissue = tissue.lower()
conf_data = pd.read_excel(SIMILARITYDIR / "data/Cell Type Annotation Atlas.xlsx", sheet_name=tissue)
atlas_datasets = list(conf_data[conf_data["queryed"] == False]["dataset_id"])
source_data = sc.read_h5ad(f"{args.data_dir}/{args.source_file}.h5ad")

ans = calculate_similarity(source_data, tissue, atlas_datasets, reduce_error, in_query)
ans_file = max(ans, key=ans.get)
ans_value = ans[ans_file]
ans_conf = {
method: conf_data.loc[conf_data["dataset_id"] == ans_file, f"{method}_step2_best_yaml"].iloc[0]
for method in ["cta_celltypist", "cta_scdeepsort", "cta_singlecellnet", "cta_actinn"]
}
return ans_file, ans_conf, ans_value


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tissue", default="Brain")
parser.add_argument("--data_dir", default=str(DANCEDIR / "examples/tuning/temp_data/train/human"))
parser.add_argument("--source_file", default="human_Brain364348b4-bc34-4fe1-a851-60d99e36cafa_data")

args = parser.parse_args()
ans_file, ans_conf, ans_value = main(args)
print(ans_file, ans_conf, ans_value)
74 changes: 74 additions & 0 deletions tests/atlas/test_atlas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
Test suite for the Atlas similarity calculation functionality.
This test verifies that the main function correctly returns:
1. The most similar dataset from the atlas
2. Its corresponding configuration settings
3. The similarity score
The test ensures:
- Return value types are correct
- Similarity score is within valid range (0-1)
- Configuration dictionary contains all required cell type annotation methods
"""

import json
import sys

import pandas as pd

from dance.settings import ATLASDIR, DANCEDIR, SIMILARITYDIR

sys.path.append(str(ATLASDIR))
from demos.main import main

from dance import logger


def test_main():
# Construct test parameters with a sample Brain tissue dataset
class Args:
tissue = "Brain"
data_dir = str(DANCEDIR / "examples/tuning/temp_data/train/human")
source_file = "human_Brain364348b4-bc34-4fe1-a851-60d99e36cafa_data"

args = Args()
logger.info(f"testing main with args: {args}")
source_id = "3643"

# Execute main function with test parameters
ans_file, ans_conf, ans_value = main(args)

# Verify return value types and ranges
assert isinstance(ans_file, str), "ans_file should be a string type"
assert isinstance(ans_value, float), "ans_value should be a float type"
assert 0 <= ans_value <= 1, "Similarity value should be between 0 and 1"

# Verify configuration dictionary structure and content
expected_methods = ["cta_celltypist", "cta_scdeepsort", "cta_singlecellnet", "cta_actinn"]
assert isinstance(ans_conf, dict), "ans_conf should be a dictionary type"
assert set(ans_conf.keys()) == set(expected_methods), "ans_conf should contain all expected methods"
assert all(isinstance(v, str) for v in ans_conf.values()), "All configuration values should be string type"

# Verify consistency with Excel spreadsheet results
data = pd.read_excel(SIMILARITYDIR / f"data/new_sim/{args.tissue.lower()}_similarity.xlsx", sheet_name=source_id,
index_col=0)
reduce_error = False
in_query = True
# Read weights
with open(
SIMILARITYDIR /
f"data/similarity_weights_results/{'reduce_error_' if reduce_error else ''}{'in_query_' if in_query else ''}sim_dict.json",
encoding='utf-8') as f:
sim_dict = json.load(f)
feature_name = sim_dict[args.tissue.lower()]["feature_name"]
w1 = sim_dict[args.tissue.lower()]["weight1"]
w2 = 1 - w1

# Calculate similarity in Excel
data.loc["similarity"] = data.loc[feature_name] * w1 + data.loc["metadata_sim"] * w2
expected_file = data.loc["similarity"].idxmax()
expected_value = data.loc["similarity", expected_file]

# Verify result consistency with Excel
assert abs(ans_value - expected_value) < 1e-4, "Calculated similarity value does not match Excel value"
assert ans_file == expected_file, "Selected most similar dataset does not match Excel result"

0 comments on commit 50141ae

Please sign in to comment.