Skip to content

Commit

Permalink
Backport PR #2425: Fix hub model test names (#2426)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Kim <46072231+martinkim0@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and martinkim0 authored Jan 22, 2024
1 parent f05d25b commit b7fae41
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions tests/hub/test_hub_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def prep_scvi_minified_hub_model(save_path: str) -> HubModel:
anndata_version=anndata.__version__,
data_modalities=["rna"],
data_is_annotated=False,
description="scVI model trained on synthetid IID data and uploaded with the minified data.",
description="scVI model trained on synthetic IID data and uploaded with the minified data.",
)
return HubModel(model_path, metadata=metadata, model_card=card)

Expand Down Expand Up @@ -222,52 +222,52 @@ def test_hub_model_create_repo_hf(save_path: str):

hub_model = prep_scvi_hub_model(save_path)
hub_model.push_to_huggingface_hub(
"scvi-tools/test_scvi_create",
"scvi-tools/test-scvi-create",
os.environ["HF_API_TOKEN"],
repo_create=True,
)
delete_repo("scvi-tools/test_scvi_create", token=os.environ["HF_API_TOKEN"])
delete_repo("scvi-tools/test-scvi-create", token=os.environ["HF_API_TOKEN"])


@pytest.mark.private
def test_hub_model_push_to_hf(save_path: str):
hub_model = prep_scvi_hub_model(save_path)
hub_model.push_to_huggingface_hub(
"scvi-tools/test_scvi",
"scvi-tools/test-scvi",
os.environ["HF_API_TOKEN"],
repo_create=False,
)

hub_model = prep_scvi_no_anndata_hub_model(save_path)
hub_model.push_to_huggingface_hub(
"scvi-tools/test_scvi_no_anndata",
"scvi-tools/test-scvi-no-anndata",
os.environ["HF_API_TOKEN"],
repo_create=False,
push_anndata=False,
)

hub_model = prep_scvi_minified_hub_model(save_path)
hub_model.push_to_huggingface_hub(
"scvi-tools/test_scvi_minified",
"scvi-tools/test-scvi-minified",
os.environ["HF_API_TOKEN"],
repo_create=False,
)


@pytest.mark.private
def test_hub_model_pull_from_hf():
hub_model = HubModel.pull_from_huggingface_hub(repo_name="scvi-tools/test_scvi")
hub_model = HubModel.pull_from_huggingface_hub(repo_name="scvi-tools/test-scvi")
assert hub_model.model is not None
assert hub_model.adata is not None

hub_model = HubModel.pull_from_huggingface_hub(
repo_name="scvi-tools/test_scvi_minified"
repo_name="scvi-tools/test-scvi-minified"
)
assert hub_model.model is not None
assert hub_model.adata is not None

hub_model = HubModel.pull_from_huggingface_hub(
repo_name="scvi-tools/test_scvi_no_anndata"
repo_name="scvi-tools/test-scvi-no-anndata"
)
with pytest.raises(ValueError):
_ = hub_model.model
Expand All @@ -281,19 +281,19 @@ def test_hub_model_pull_from_hf():
@pytest.mark.private
def test_hub_model_push_to_s3(save_path: str):
hub_model = prep_scvi_hub_model(save_path)
hub_model.push_to_s3("scvi-tools", "tests/hub/test_scvi")
hub_model.push_to_s3("scvi-tools", "tests/hub/test-scvi")

hub_model = prep_scvi_no_anndata_hub_model(save_path)
with pytest.raises(ValueError):
hub_model.push_to_s3(
"scvi-tools", "tests/hub/test_scvi_no_anndata", push_anndata=True
"scvi-tools", "tests/hub/test-scvi-no-anndata", push_anndata=True
)
hub_model.push_to_s3(
"scvi-tools", "tests/hub/test_scvi_no_anndata", push_anndata=False
"scvi-tools", "tests/hub/test-scvi-no-anndata", push_anndata=False
)

hub_model = prep_scvi_minified_hub_model(save_path)
hub_model.push_to_s3("scvi-tools", "tests/hub/test_scvi_minified")
hub_model.push_to_s3("scvi-tools", "tests/hub/test-scvi-minified")


@pytest.mark.private
Expand All @@ -302,21 +302,21 @@ def test_hub_model_pull_from_s3():

hub_model = HubModel.pull_from_s3(
"scvi-tools",
"tests/hub/test_scvi",
"tests/hub/test-scvi",
)
assert hub_model.model is not None
assert hub_model.adata is not None

hub_model = HubModel.pull_from_s3("scvi-tools", "tests/hub/test_scvi_minified")
hub_model = HubModel.pull_from_s3("scvi-tools", "tests/hub/test-scvi-minified")
assert hub_model.model is not None
assert hub_model.adata is not None

with pytest.raises(ClientError):
hub_model = HubModel.pull_from_s3("scvi-tools", "tests/hub/test_scvi_no_anndata")
hub_model = HubModel.pull_from_s3("scvi-tools", "tests/hub/test-scvi-no-anndata")

hub_model = HubModel.pull_from_s3(
"scvi-tools",
"tests/hub/test_scvi_no_anndata",
"tests/hub/test-scvi-no-anndata",
pull_anndata=False,
)
with pytest.raises(ValueError):
Expand Down

0 comments on commit b7fae41

Please sign in to comment.