diff --git a/scdataloader/utils.py b/scdataloader/utils.py index ce56604..625e124 100644 --- a/scdataloader/utils.py +++ b/scdataloader/utils.py @@ -127,15 +127,15 @@ def getBiomartTable( cache_folder = os.path.expanduser(cache_folder) createFoldersFor(cache_folder) - cachefile = os.path.join(cache_folder, ".biomart.csv") + cachefile = os.path.join(cache_folder, ".biomart.parquet") if useCache & os.path.isfile(cachefile): print("fetching gene names from biomart cache") - res = pd.read_csv(cachefile) + res = pd.read_parquet(cachefile) else: print("downloading gene names from biomart") res = _fetchFromServer(ensemble_server, attr + attributes, database=database) - res.to_csv(cachefile, index=False) + res.to_parquet(cachefile, index=False) res.columns = attr + attributes if type(res) is not type(pd.DataFrame()): raise ValueError("should be a dataframe") @@ -368,7 +368,14 @@ def load_genes(organisms: Union[str, list] = "NCBITaxon:9606"): # "NCBITaxon:10 genesdf["organism"] = organism organismdf.append(genesdf) organismdf = pd.concat(organismdf) - for col in ["source_id", "run_id", "created_by_id", "updated_at", "stable_id", "created_at"]: + for col in [ + "source_id", + "run_id", + "created_by_id", + "updated_at", + "stable_id", + "created_at", + ]: if col in organismdf.columns: organismdf.drop(columns=[col], inplace=True) return organismdf diff --git a/tests/test_base.py b/tests/test_base.py index f9cf656..1651591 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -45,6 +45,7 @@ def test_base(): how="most expr", # for the collator (most expr genes only will be selected) max_len=1000, # only the 1000 most expressed batch_size=64, + do_gene_pos=False, num_workers=1, clss_to_weight=["organism_ontology_term_id", "cell_type_ontology_term_id"], clss_to_pred=["organism_ontology_term_id", "cell_type_ontology_term_id"],