
Misc changes #230

Merged
merged 12 commits on Dec 19, 2024
100 changes: 75 additions & 25 deletions polaris/hub/client.py
@@ -277,7 +277,8 @@ def login(self, overwrite: bool = False, auto_open_browser: bool = True):
# =========================

def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
"""List all available datasets on the Polaris Hub.
"""List all available datasets (v1 and v2) on the Polaris Hub.
We prioritize v2 datasets over v1 datasets.

Args:
limit: The maximum number of datasets to return.
@@ -287,17 +288,42 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
A list of dataset names in the format `owner/dataset_name`.
"""
with ProgressIndicator(
start_msg="Fetching artifacts...",
success_msg="Fetched artifacts.",
start_msg="Fetching datasets...",
success_msg="Fetched datasets.",
error_msg="Failed to fetch datasets.",
):
response = self._base_request_to_hub(
url="/v1/dataset", method="GET", params={"limit": limit, "offset": offset}
# Step 1: Fetch enough v2 datasets to cover the offset and limit
v2_response = self._base_request_to_hub(
url="/v2/dataset", method="GET", params={"limit": limit + offset, "offset": 0}
)
response_data = response.json()
dataset_list = [bm["artifactId"] for bm in response_data["data"]]
v2_data = v2_response.json().get("data", [])
v2_datasets = [dataset["artifactId"] for dataset in v2_data]

# Apply offset and limit to v2 datasets
v2_datasets_offset = v2_datasets[offset : offset + limit]

# If v2 datasets satisfy the limit, return them
if len(v2_datasets_offset) >= limit:
return v2_datasets_offset

return dataset_list
# Step 2: Calculate the remaining limit and fetch v1 datasets
remaining_limit = max(0, limit - len(v2_datasets_offset))

v1_datasets = []
if remaining_limit > 0:
v1_response = self._base_request_to_hub(
url="/v1/dataset",
method="GET",
params={"limit": remaining_limit, "offset": max(0, offset - len(v2_datasets))},
)
v1_data = v1_response.json().get("data", [])
v1_datasets = [dataset["artifactId"] for dataset in v1_data]

# Combine the v2 and v1 datasets
combined_datasets = v2_datasets_offset + v1_datasets

# Ensure the final combined list respects the limit
return combined_datasets[:limit]
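
For readers following the new pagination logic, here is a standalone sketch that mirrors the merge-and-truncate behaviour with plain in-memory lists instead of Hub requests. The helper name combine_pages and the example artifact IDs are illustrative only and not part of the client API.

# A minimal, self-contained mirror of the v2-first pagination above.
# `combine_pages` is a hypothetical helper for illustration, not client code.
def combine_pages(v2: list[str], v1: list[str], limit: int = 100, offset: int = 0) -> list[str]:
    v2_window = v2[offset : offset + limit]      # v2 entries take priority
    if len(v2_window) >= limit:
        return v2_window
    remaining = limit - len(v2_window)
    v1_offset = max(0, offset - len(v2))         # leftover offset once v2 is exhausted
    return (v2_window + v1[v1_offset : v1_offset + remaining])[:limit]

# Worked example: 3 v2 datasets and 4 v1 datasets, page size 4.
v2 = ["org/ds-v2-a", "org/ds-v2-b", "org/ds-v2-c"]
v1 = ["org/ds-v1-a", "org/ds-v1-b", "org/ds-v1-c", "org/ds-v1-d"]
assert combine_pages(v2, v1, limit=4, offset=0) == ["org/ds-v2-a", "org/ds-v2-b", "org/ds-v2-c", "org/ds-v1-a"]
assert combine_pages(v2, v1, limit=4, offset=2) == ["org/ds-v2-c", "org/ds-v1-a", "org/ds-v1-b", "org/ds-v1-c"]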

def get_dataset(
self,
@@ -322,7 +348,7 @@ def get_dataset(
error_msg="Failed to fetch dataset.",
):
try:
return self._get_v1_dataset(owner, name, ArtifactSubtype.STANDARD.value, verify_checksum)
return self._get_v1_dataset(owner, name, ArtifactSubtype.STANDARD, verify_checksum)
except PolarisRetrieveArtifactError:
# If the v1 dataset is not found, try to load a v2 dataset
return self._get_v2_dataset(owner, name)
@@ -347,7 +373,7 @@ def _get_v1_dataset(
"""
url = (
f"/v1/dataset/{owner}/{name}"
if artifact_type == ArtifactSubtype.STANDARD.value
if artifact_type == ArtifactSubtype.STANDARD
else f"/v2/competition/dataset/{owner}/{name}"
)
response = self._base_request_to_hub(url=url, method="GET")
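
Several comparisons in this diff drop .value and compare against ArtifactSubtype.STANDARD directly. The sketch below shows why this works if, as assumed here, ArtifactSubtype is a str-backed enum; the definition shown is illustrative, not the actual polaris one.

# Illustrative, assumed definition -- not the real polaris ArtifactSubtype.
from enum import Enum

class ArtifactSubtype(str, Enum):
    STANDARD = "standard"
    COMPETITION = "competition"

# str-backed members compare equal to both the member and its raw string value,
# so `artifact_type == ArtifactSubtype.STANDARD` works for either representation.
assert ArtifactSubtype.STANDARD == ArtifactSubtype.STANDARD
assert ArtifactSubtype.STANDARD == "standard"
assert ArtifactSubtype("standard") is ArtifactSubtype.STANDARD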
@@ -407,18 +433,44 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]:
A list of benchmark names in the format `owner/benchmark_name`.
"""
with ProgressIndicator(
start_msg="Fetching artifacts...",
success_msg="Fetched artifacts.",
start_msg="Fetching benchmarks...",
success_msg="Fetched benchmarks.",
error_msg="Failed to fetch benchmarks.",
):
# TODO (cwognum): What to do with pagination, i.e. limit and offset?
response = self._base_request_to_hub(
url="/v1/benchmark", method="GET", params={"limit": limit, "offset": offset}
# Step 1: Fetch enough v2 benchmarks to cover the offset and limit
v2_response = self._base_request_to_hub(
url="/v2/benchmark", method="GET", params={"limit": limit + offset, "offset": 0}
)
response_data = response.json()
benchmarks_list = [f"{HubOwner(**bm['owner'])}/{bm['name']}" for bm in response_data["data"]]
v2_data = v2_response.json().get("data", [])
v2_benchmarks = [f"{HubOwner(**benchmark['owner'])}/{benchmark['name']}" for benchmark in v2_data]

# Apply offset and limit to v2 benchmarks
v2_benchmarks_offset = v2_benchmarks[offset : offset + limit]

# If v2 benchmarks satisfy the limit, return them
if len(v2_benchmarks_offset) >= limit:
return v2_benchmarks_offset

return benchmarks_list
# Step 2: Calculate the remaining limit and fetch v1 benchmarks
remaining_limit = max(0, limit - len(v2_benchmarks_offset))

v1_benchmarks = []
if remaining_limit > 0:
v1_response = self._base_request_to_hub(
url="/v1/benchmark",
method="GET",
params={"limit": remaining_limit, "offset": max(0, offset - len(v2_benchmarks))},
)
v1_data = v1_response.json().get("data", [])
v1_benchmarks = [
f"{HubOwner(**benchmark['owner'])}/{benchmark['name']}" for benchmark in v1_data
]

# Combine the v2 and v1 benchmarks
combined_benchmarks = v2_benchmarks_offset + v1_benchmarks

# Ensure the final combined list respects the limit
return combined_benchmarks[:limit]
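
As a usage sketch, assuming the class in this file is the public PolarisHubClient and that it can be used as a context manager, the merged listings are consumed exactly as before; callers see a single list with v2 entries first.

from polaris.hub.client import PolarisHubClient

with PolarisHubClient() as client:
    datasets = client.list_datasets(limit=10)      # v2 artifact IDs first, then v1
    benchmarks = client.list_benchmarks(limit=10)  # same merge-and-truncate behaviour
    print(datasets[:3], benchmarks[:3])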

def get_benchmark(
self,
@@ -559,7 +611,7 @@ def upload_dataset(

if isinstance(dataset, DatasetV1):
return self._upload_v1_dataset(
dataset, ArtifactSubtype.STANDARD.value, timeout, access, owner, if_exists
dataset, ArtifactSubtype.STANDARD, timeout, access, owner, if_exists
)
elif isinstance(dataset, DatasetV2):
return self._upload_v2_dataset(dataset, timeout, access, owner, if_exists)
@@ -607,7 +659,7 @@ def _upload_v1_dataset(
# We do so separately for the Zarr archive and Parquet file.
url = (
f"/v1/dataset/{dataset.artifact_id}"
if artifact_type == ArtifactSubtype.STANDARD.value
if artifact_type == ArtifactSubtype.STANDARD
else f"/v2/competition/dataset/{dataset.owner}/{dataset.name}"
)
response = self._base_request_to_hub(
@@ -651,7 +703,7 @@ def _upload_v1_dataset(
)

base_artifact_url = (
"datasets" if artifact_type == ArtifactSubtype.STANDARD.value else "/competition/datasets"
"datasets" if artifact_type == ArtifactSubtype.STANDARD else "/competition/datasets"
)
progress_indicator.update_success_msg(
f"Your {artifact_type} dataset has been successfully uploaded to the Hub. "
@@ -754,7 +806,7 @@ def upload_benchmark(
access: Grant public or private access to result
owner: Which Hub user or organization owns the artifact. Takes precedence over `benchmark.owner`.
"""
return self._upload_benchmark(benchmark, ArtifactSubtype.STANDARD.value, access, owner)
return self._upload_benchmark(benchmark, ArtifactSubtype.STANDARD, access, owner)

def _upload_benchmark(
self,
@@ -797,9 +849,7 @@ def _upload_benchmark(
benchmark_json["datasetArtifactId"] = benchmark.dataset.artifact_id
benchmark_json["access"] = access

path_params = (
"/v1/benchmark" if artifact_type == ArtifactSubtype.STANDARD.value else "/v2/competition"
)
path_params = "/v1/benchmark" if artifact_type == ArtifactSubtype.STANDARD else "/v2/competition"
url = f"{path_params}/{benchmark.owner}/{benchmark.name}"
response = self._base_request_to_hub(url=url, method="PUT", json=benchmark_json)
response_data = response.json()
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -374,7 +374,7 @@ def test_multi_task_benchmark_multiple_test_sets(test_dataset, regression_metric
@pytest.fixture(scope="function")
def test_docking_dataset(tmp_path, sdf_files, test_org_owner):
# toy docking dataset
factory = DatasetFactory(tmp_path / "ligands.zarr")
factory = DatasetFactory(str(tmp_path / "ligands.zarr"))

converter = SDFConverter(mol_prop_as_cols=True)
factory.register_converter("sdf", converter)
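
The test updates in this PR wrap tmp_path (a pathlib.Path) in str(...) before handing it to DatasetFactory and the create_dataset_from_file* helpers, which suggests those APIs expect plain string paths. If a helper should tolerate both, a common pattern is os.fspath; the function below is a hypothetical illustration, not polaris code.

import os
from pathlib import Path
from typing import Union

def normalize_root_path(path: Union[str, os.PathLike]) -> str:
    """Return a plain string path whether given a str or a pathlib.Path."""
    return os.fspath(path)

print(normalize_root_path(Path("/tmp") / "ligands.zarr"))  # -> /tmp/ligands.zarr
print(normalize_root_path("/tmp/ligands.zarr"))            # unchanged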
3 changes: 2 additions & 1 deletion tests/test_benchmark.py
@@ -197,11 +197,12 @@ def _check_for_failure(_kwargs):
kwargs["target_cols"] = kwargs["target_cols"][1:] + ["iupac"]
_check_for_failure(kwargs)

# Input columns
kwargs = obj.model_dump()
kwargs["input_cols"] = kwargs["input_cols"][1:] + ["iupac"]
_check_for_failure(kwargs)

# --- Don't fail if not checksum is provided ---
# --- Don't fail if no checksum is provided ---
kwargs["md5sum"] = None
dataset = cls(**kwargs)
assert dataset.md5sum is not None
14 changes: 7 additions & 7 deletions tests/test_dataset.py
@@ -88,7 +88,7 @@ def test_dataset_checksum(test_dataset):
def test_dataset_from_zarr(zarr_archive, tmp_path):
"""Test whether loading works when the zarr archive contains a single array or multiple arrays."""
archive = zarr_archive
dataset = create_dataset_from_file(archive, tmp_path / "data")
dataset = create_dataset_from_file(archive, str(tmp_path / "data"))

assert len(dataset.table) == 100
for i in range(100):
Expand All @@ -115,8 +115,8 @@ def test_dataset_from_zarr_to_json_and_back(zarr_archive, tmp_path):
can be saved to and loaded from json.
"""

json_dir = tmp_path / "json"
zarr_dir = tmp_path / "zarr"
json_dir = str(tmp_path / "json")
zarr_dir = str(tmp_path / "zarr")

archive = zarr_archive
dataset = create_dataset_from_file(archive, zarr_dir)
@@ -132,8 +132,8 @@ def test_dataset_caching(zarr_archive, tmp_path):
def test_dataset_caching(zarr_archive, tmp_path):
"""Test whether the dataset remains the same after caching."""

original_dataset = create_dataset_from_file(zarr_archive, tmp_path / "original1")
cached_dataset = create_dataset_from_file(zarr_archive, tmp_path / "original2")
original_dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "original1"))
cached_dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "original2"))
assert original_dataset == cached_dataset

cached_dataset._cache_dir = str(tmp_path / "cached")
@@ -153,7 +153,7 @@ def test_dataset_index():

def test_dataset_in_memory_optimization(zarr_archive, tmp_path):
"""Check if optimization makes a default Zarr archive faster."""
dataset = create_dataset_from_file(zarr_archive, tmp_path / "dataset")
dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "dataset"))
subset = Subset(dataset=dataset, indices=range(100), input_cols=["A"], target_cols=["B"])

t1 = perf_counter()
@@ -208,7 +208,7 @@ def test_dataset__get_item__():
def test_dataset__get_item__with_pointer_columns(zarr_archive, tmp_path):
"""Test the __getitem__() interface for a dataset with pointer columns (i.e. part of the data stored in Zarr)."""

dataset = create_dataset_from_file(zarr_archive, tmp_path / "data")
dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "data"))
root = zarr.open(zarr_archive)

# Get a specific cell
6 changes: 3 additions & 3 deletions tests/test_dataset_v2.py
@@ -66,7 +66,7 @@ def test_dataset_v2_load_to_memory(test_dataset_v2):


def test_dataset_v2_serialization(test_dataset_v2, tmp_path):
save_dir = tmp_path / "save_dir"
save_dir = str(tmp_path / "save_dir")
path = test_dataset_v2.to_json(save_dir)
new_dataset = DatasetV2.from_json(path)
for i in range(5):
@@ -86,7 +86,7 @@ def test_dataset_v1_v2_compatibility(test_dataset, tmp_path):
# We can thus also save these same arrays to a Zarr archive
df = test_dataset.table

path = tmp_path / "data/v1v2.zarr"
path = str(tmp_path / "data" / "v1v2.zarr")

root = zarr.open(path, "w")
root.array("smiles", data=df["smiles"].values, dtype=object, object_codec=numcodecs.VLenUTF8())
@@ -96,7 +96,7 @@ def test_dataset_v1_v2_compatibility(test_dataset, tmp_path):
zarr.consolidate_metadata(path)

kwargs = test_dataset.model_dump(exclude=["table", "zarr_root_path"])
dataset = DatasetV2(**kwargs, zarr_root_path=str(path))
dataset = DatasetV2(**kwargs, zarr_root_path=path)

subset_1 = Subset(dataset=test_dataset, indices=range(5), input_cols=["smiles"], target_cols=["calc"])
subset_2 = Subset(dataset=dataset, indices=range(5), input_cols=["smiles"], target_cols=["calc"])
6 changes: 3 additions & 3 deletions tests/test_evaluate.py
@@ -1,4 +1,4 @@
import os
from pathlib import Path

import datamol as dm
import numpy as np
@@ -17,7 +17,7 @@
from polaris.utils.types import HubOwner


def test_result_to_json(tmp_path: str, test_user_owner: HubOwner):
def test_result_to_json(tmp_path: Path, test_user_owner: HubOwner):
scores = pd.DataFrame(
{
"Test set": ["A", "A", "A", "A", "B", "B", "B", "B"],
@@ -41,7 +41,7 @@ def test_result_to_json(tmp_path: str, test_user_owner: HubOwner):
contributors=["my-user", "other-user"],
)

path = os.path.join(tmp_path, "result.json")
path = str(tmp_path / "result.json")
result.to_json(path)

BenchmarkResults.from_json(path)
20 changes: 10 additions & 10 deletions tests/test_factory.py
@@ -39,15 +39,15 @@ def _check_dataset(dataset, ground_truth, mol_props_as_col):

def test_sdf_zarr_conversion(sdf_file, caffeine, tmp_path):
"""Test conversion between SDF and Zarr with utility function"""
dataset = create_dataset_from_file(sdf_file, tmp_path / "archive.zarr")
dataset = create_dataset_from_file(sdf_file, str(tmp_path / "archive.zarr"))
_check_dataset(dataset, [caffeine], True)


@pytest.mark.parametrize("mol_props_as_col", [True, False])
def test_factory_sdf_with_prop_as_col(sdf_file, caffeine, tmp_path, mol_props_as_col):
"""Test conversion between SDF and Zarr with factory pattern"""

factory = DatasetFactory(tmp_path / "archive.zarr")
factory = DatasetFactory(str(tmp_path / "archive.zarr"))

converter = SDFConverter(mol_prop_as_cols=mol_props_as_col)
factory.register_converter("sdf", converter)
@@ -60,7 +60,7 @@ def test_factory_sdf_with_prop_as_col(sdf_file, caffeine, tmp_path, mol_props_as

def test_zarr_to_zarr_conversion(zarr_archive, tmp_path):
"""Test conversion between Zarr and Zarr with utility function"""
dataset = create_dataset_from_file(zarr_archive, tmp_path / "archive.zarr")
dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "archive.zarr"))
assert len(dataset) == 100
assert len(dataset.columns) == 2
assert all(c in dataset.columns for c in ["A", "B"])
@@ -71,7 +71,7 @@ def test_zarr_to_zarr_conversion(zarr_archive, tmp_path):
def test_zarr_with_factory_pattern(zarr_archive, tmp_path):
"""Test conversion between Zarr and Zarr with factory pattern"""

factory = DatasetFactory(tmp_path / "archive.zarr")
factory = DatasetFactory(str(tmp_path / "archive.zarr"))
converter = ZarrConverter()
factory.register_converter("zarr", converter)
factory.add_from_file(zarr_archive)
@@ -90,7 +90,7 @@ def test_zarr_with_factory_pattern(zarr_archive, tmp_path):

def test_factory_pdb(pdbs_structs, pdb_paths, tmp_path):
"""Test conversion between PDB file and Zarr with factory pattern"""
factory = DatasetFactory(tmp_path / "pdb.zarr")
factory = DatasetFactory(str(tmp_path / "pdb.zarr"))

converter = PDBConverter()
factory.register_converter("pdb", converter)
@@ -104,7 +104,7 @@ def test_factory_pdb(pdbs_structs, pdb_paths, tmp_path):
def test_factory_pdbs(pdbs_structs, pdb_paths, tmp_path):
"""Test conversion between PDB files and Zarr with factory pattern"""

factory = DatasetFactory(tmp_path / "pdbs.zarr")
factory = DatasetFactory(str(tmp_path / "pdbs.zarr"))

converter = PDBConverter()
factory.register_converter("pdb", converter)
@@ -119,7 +119,7 @@ def test_factory_pdbs(pdbs_structs, pdb_paths, tmp_path):
def test_pdbs_zarr_conversion(pdbs_structs, pdb_paths, tmp_path):
"""Test conversion between PDBs and Zarr with utility function"""

dataset = create_dataset_from_files(pdb_paths, tmp_path / "pdbs_2.zarr", axis=0)
dataset = create_dataset_from_files(pdb_paths, str(tmp_path / "pdbs_2.zarr"), axis=0)

assert dataset.table.shape[0] == len(pdb_paths)
_check_pdb_dataset(dataset, pdbs_structs)
@@ -128,7 +128,7 @@ def test_pdbs_zarr_conversion(pdbs_structs, pdb_paths, tmp_path):
def test_factory_sdfs(sdf_files, caffeine, ibuprofen, tmp_path):
"""Test conversion between SDF and Zarr with factory pattern"""

factory = DatasetFactory(tmp_path / "sdfs.zarr")
factory = DatasetFactory(str(tmp_path / "sdfs.zarr"))

converter = SDFConverter(mol_prop_as_cols=True)
factory.register_converter("sdf", converter)
@@ -142,7 +142,7 @@ def test_factory_sdfs(sdf_files, caffeine, ibuprofen, tmp_path):
def test_factory_sdf_pdb(sdf_file, pdb_paths, caffeine, pdbs_structs, tmp_path):
"""Test conversion between SDF and PDB from files to Zarr with factory pattern"""

factory = DatasetFactory(tmp_path / "sdf_pdb.zarr")
factory = DatasetFactory(str(tmp_path / "sdf_pdb.zarr"))

sdf_converter = SDFConverter(mol_prop_as_cols=False)
factory.register_converter("sdf", sdf_converter)
@@ -158,7 +158,7 @@ def test_factory_sdf_pdb(sdf_file, pdb_paths, caffeine, pdbs_structs, tmp_path):


def test_factory_from_files_same_column(sdf_files, pdb_paths, tmp_path):
factory = DatasetFactory(tmp_path / "files.zarr")
factory = DatasetFactory(str(tmp_path / "files.zarr"))

sdf_converter = SDFConverter(mol_prop_as_cols=False)
factory.register_converter("sdf", sdf_converter)