Skip to content

Commit

Permalink
Merge pull request #72 from Lotfollahi-lab/revision
Browse files Browse the repository at this point in the history
Revision
  • Loading branch information
sebastianbirk authored Aug 22, 2024
2 parents d22ef4e + 58ffcc7 commit 46a63d0
Show file tree
Hide file tree
Showing 9 changed files with 1,508 additions and 1,921 deletions.
781 changes: 323 additions & 458 deletions docs/tutorials/notebooks/mouse_brain_multimodal.ipynb

Large diffs are not rendered by default.

563 changes: 285 additions & 278 deletions docs/tutorials/notebooks/mouse_cns_sample_integration.ipynb

Large diffs are not rendered by default.

954 changes: 261 additions & 693 deletions docs/tutorials/notebooks/mouse_cns_single_sample.ipynb

Large diffs are not rendered by default.

639 changes: 293 additions & 346 deletions docs/tutorials/notebooks/mouse_cns_spatial_reference_mapping.ipynb

Large diffs are not rendered by default.

18 changes: 16 additions & 2 deletions src/nichecompass/models/nichecompass.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ class NicheCompass(BaseModelMixin):
in masks but can be learned de novo).
cat_covariates_embeds_nums:
List of number of embedding nodes for all categorical covariates.
use_cuda_if_available:
If `True`, use cuda if available.
seed:
Random seed to get reproducible results.
kwargs:
NicheCompass kwargs (to support legacy versions).
"""
Expand Down Expand Up @@ -240,6 +243,8 @@ def __init__(self,
n_addon_gp: int=100,
cat_covariates_embeds_nums: Optional[List[int]]=None,
include_edge_kl_loss: bool=True,
use_cuda_if_available: bool=True,
seed: int=0,
**kwargs):
self.adata = adata
self.adata_atac = adata_atac
Expand Down Expand Up @@ -282,6 +287,15 @@ def __init__(self,
self.active_gp_thresh_ratio_ = active_gp_thresh_ratio
self.active_gp_type_ = active_gp_type
self.include_edge_kl_loss_ = include_edge_kl_loss
self.seed_ = seed

# Set seed for reproducibility
np.random.seed(self.seed_)
if use_cuda_if_available & torch.cuda.is_available():
torch.cuda.manual_seed(self.seed_)
torch.manual_seed(self.seed_)
else:
torch.manual_seed(self.seed_)

# Retrieve gene program masks
if gp_targets_mask_key in adata.varm:
Expand Down Expand Up @@ -590,7 +604,7 @@ def train(self,
lambda_l1_masked: float=0.,
l1_targets_categories: Optional[list]=["target_gene"],
l1_sources_categories: Optional[list]=None,
lambda_l1_addon: float=0.,
lambda_l1_addon: float=30.,
edge_val_ratio: float=0.1,
node_val_ratio: float=0.1,
edge_batch_size: int=256,
Expand Down
2 changes: 2 additions & 0 deletions src/nichecompass/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ def __init__(self,
print("\n--- INITIALIZING TRAINER ---")

# Set seed and use GPU if available
np.random.seed(self.seed_)
if use_cuda_if_available & torch.cuda.is_available():
torch.cuda.manual_seed(self.seed_)
torch.manual_seed(self.seed_)
self.device = torch.device("cuda")
else:
torch.manual_seed(self.seed_)
Expand Down
6 changes: 4 additions & 2 deletions src/nichecompass/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from .gene_programs import (add_gps_from_gp_dict_to_adata,
extract_gp_dict_from_collectri_tf_network,
extract_gp_dict_from_nichenet_lrt_interactions,
extract_gp_dict_from_mebocost_es_interactions,
extract_gp_dict_from_mebocost_ms_interactions,
extract_gp_dict_from_omnipath_lr_interactions,
filter_and_combine_gp_dict_gps,
filter_and_combine_gp_dict_gps_v2,
get_unique_genes_from_gp_dict)

__all__ = ["add_gps_from_gp_dict_to_adata",
Expand All @@ -25,9 +26,10 @@
"visualize_communication_gp_network",
"extract_gp_dict_from_collectri_tf_network",
"extract_gp_dict_from_nichenet_lrt_interactions",
"extract_gp_dict_from_mebocost_es_interactions",
"extract_gp_dict_from_mebocost_ms_interactions",
"extract_gp_dict_from_omnipath_lr_interactions",
"filter_and_combine_gp_dict_gps",
"filter_and_combine_gp_dict_gps_v2",
"get_gene_annotations",
"generate_enriched_gp_info_plots",
"plot_non_zero_gene_count_means_dist",
Expand Down
460 changes: 321 additions & 139 deletions src/nichecompass/utils/gene_programs.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/nichecompass/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def load_R_file_as_df(R_file_path: str,
save_df_to_disk: bool=False,
df_save_path: Optional[str]=None) -> pd.DataFrame:
"""
Helper to load an R file either from ´url´ if specified or from ´file_path´
on disk and convert it to a pandas DataFrame.
Helper to load an R file either from ´url´ if specified or from
´R_file_path´ on disk and convert it to a pandas DataFrame.
Parameters
----------
Expand All @@ -39,7 +39,7 @@ def load_R_file_as_df(R_file_path: str,
"""
if url is None:
if not os.path.exists(R_file_path):
raise ValueError("Please specify a valid ´file_path´ or ´url´.")
raise ValueError("Please specify a valid ´R_file_path´ or ´url´.")
result_odict = pyreadr.read_r(R_file_path)
else:
result_odict = pyreadr.read_r(pyreadr.download_file(url, R_file_path))
Expand Down

0 comments on commit 46a63d0

Please sign in to comment.