Skip to content

Commit

Permalink
Merge pull request #958 from haddocking/956-add-ilrmsd-to-capri_clt-t…
Browse files Browse the repository at this point in the history
…ables

update `caprieval` integration tests
  • Loading branch information
rvhonorato authored Aug 1, 2024
2 parents 62b7968 + 53fec1b commit a91d8d0
Showing 1 changed file with 222 additions and 102 deletions.
324 changes: 222 additions & 102 deletions integration_tests/test_caprieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import shutil
import tempfile
from pathlib import Path
from typing import Union

import pytest

Expand Down Expand Up @@ -39,6 +40,79 @@ def model_list():
]


@pytest.fixture
def expected_clt_data() -> list[dict[str, Union[int, str, float]]]:
return [
{
"cluster_rank": "-",
"cluster_id": "-",
"n": 2,
"under_eval": "yes",
"score": float("nan"),
"score_std": float("nan"),
"irmsd": 4.163,
"irmsd_std": 4.163,
"fnat": 0.525,
"fnat_std": 0.475,
"lrmsd": 10.469,
"lrmsd_std": 10.469,
"dockq": 0.537,
"dockq_std": 0.463,
"ilrmsd": 9.124,
"ilrmsd_std": 9.124,
"air": float("nan"),
"air_std": float("nan"),
"bsa": float("nan"),
"bsa_std": float("nan"),
"desolv": float("nan"),
"desolv_std": float("nan"),
"elec": float("nan"),
"elec_std": float("nan"),
"total": float("nan"),
"total_std": float("nan"),
"vdw": float("nan"),
"vdw_std": float("nan"),
"caprieval_rank": 1,
}
]


@pytest.fixture
def expected_ss_data() -> list[dict[str, Union[int, str, float]]]:
return [
{
"model": "",
"md5": "-",
"caprieval_rank": 1,
"score": float("nan"),
"irmsd": 0.000,
"fnat": 1.000,
"lrmsd": 0.000,
"ilrmsd": 0.000,
"dockq": 1.000,
"cluster_id": "-",
"cluster_ranking": "-",
"model-cluster_ranking": "-",
"energy_term": 0.000,
},
{
"model": "",
"md5": "-",
"caprieval_rank": 2,
"score": float("nan"),
"irmsd": 8.327,
"fnat": 0.050,
"lrmsd": 20.937,
"ilrmsd": 18.248,
"dockq": 0.074,
"cluster_id": "-",
"cluster_ranking": "-",
"model-cluster_ranking": "-",
"energy_term": 0.000,
},
]


class MockPreviousIO:
def __init__(self, path):
self.path = path
Expand Down Expand Up @@ -71,7 +145,134 @@ def output(self):
return None


def evaluate_caprieval_execution(module: CaprievalModule, model_list):
def _cast_float_str_int(v: Union[int, str, float]) -> Union[int, str, float]:
"""Helper function to cast a value string to a float, int or str."""
try:
return int(v)
except ValueError:
try:
return float(v)
except ValueError:
return v


def _compare_polymorphic_data(
expected_data: list[dict[str, Union[int, str, float]]],
oberseved_data: list[dict[str, Union[int, str, float]]],
):
"""Helper function to compare a list of dictionaries with polymorphic values."""
for k, v in zip(expected_data, oberseved_data):
for key in k:
v1 = k[key]
v2 = v[key]

# Check the type match
assert type(v1) == type(v2), f"Type mismatch for {key}"

# Check float
if isinstance(v1, float):
if math.isnan(v1):
assert isinstance(v2, float) and math.isnan(
v2
), f"Value mismatch for {key}"
else:
assert (
isinstance(v2, (int, float))
and pytest.approx(v1, rel=1e-3) == v2
), f"Value mismatch for {key}"

# Check int
elif isinstance(v1, int):
assert v1 == v2, f"Value mismatch for {key}"

# Check str
elif isinstance(v1, str):
assert v1 == v2, f"Value mismatch for {key}"

# Value is not float, int or str
else:
raise ValueError(f"Unknown type for {key}")


def _check_capri_ss_tsv(
capri_file: str, expected_data: list[dict[str, Union[int, str, float]]]
):
"""Helper function to check the content of the capri_ss.tsv file."""
with open(capri_file) as f:
lines = f.readlines()

# Check the header
expected_header_cols = list(expected_data[0].keys())
observed_header_cols = lines[0].strip().split("\t")

# Check if all they have the same lenght
assert len(observed_header_cols) == len(expected_header_cols), "Header mismatch"

for col_name in expected_header_cols:
assert col_name in observed_header_cols, f"{col_name} not found in the header"

oberseved_data: list[dict[str, Union[int, str, float]]] = []
data = lines[1:]
for line in data:
values = line.strip().split("\t")

# Check there is one value for each column
assert len(values) == len(expected_header_cols), "Values mismatch"

data_dict = {}
for h, v in zip(expected_header_cols, values):
data_dict[h] = _cast_float_str_int(v)

oberseved_data.append(data_dict)

# Cannot compare the names of the models, since the observed will be a random string
[d.pop("model") for d in expected_data], [d.pop("model") for d in oberseved_data]

_compare_polymorphic_data(expected_data, oberseved_data)


def _check_capri_clt_tsv(
capri_file: str, expected_data: list[dict[str, Union[int, str, float]]]
):
"""Helper function to check the content of the capri_clt.tsv file."""
with open(capri_file) as f:
lines = f.readlines()

# There are several `#` lines in the file, these are comments and can be ignored
lines = [line for line in lines if not line.startswith("#")]

# Check header
expected_header_cols = list(expected_data[0].keys())
observed_header_cols = lines[0].strip().split("\t")

# Check if all the columns are present
assert len(observed_header_cols) == len(expected_header_cols), "Header mismatch"

for col in expected_header_cols:
assert col in observed_header_cols, f"{col} not found in the header"

data = lines[1:]
oberseved_data: list[dict[str, Union[int, str, float]]] = []
for line in data:
values = line.strip().split("\t")

# Check if there is one value for each column
assert len(values) == len(expected_header_cols), "Values mismatch"

data_dic = {}
for h, v in zip(expected_header_cols, values):
data_dic[h] = _cast_float_str_int(v)

oberseved_data.append(data_dic)

assert len(oberseved_data) == len(expected_data), "Data mismatch"

_compare_polymorphic_data(expected_data, oberseved_data)


def evaluate_caprieval_execution(
module: CaprievalModule, model_list, ss_data, clt_data
):
"""Helper function to check if `caprieval` executed properly."""

# Check if the files were written
Expand All @@ -92,118 +293,37 @@ def evaluate_caprieval_execution(module: CaprievalModule, model_list):
assert module.output_models[0].file_name == model_list[0].file_name
assert module.output_models[1].file_name == model_list[1].file_name

# The models do not hold the capri metrics, so check the output files to see if they were written
with open(Path(module.path, "capri_ss.tsv")) as f:
lines = f.readlines()

# Check the values
assert len(lines) == 3, "There should be 3 lines in the capri_ss.tsv"

# Check the header
# model md5 caprieval_rank score irmsd fnat lrmsd ilrmsd dockq cluster_id cluster_ranking model-cluster_ranking
expected_colnames = [
"model",
"md5",
"caprieval_rank",
"score",
"irmsd",
"fnat",
"lrmsd",
"ilrmsd",
"dockq",
"cluster_ranking",
"model-cluster_ranking",
"energy_term",
]
header = lines[0].strip().split("\t")
for col_name in expected_colnames:
assert col_name in header, f"{col_name} not found in the header"

# Check the values
data = lines[1:]
expected_data = [
{
"model": "",
"md5": "-",
"caprieval_rank": 1,
"score": float("nan"),
"irmsd": 0.000,
"fnat": 1.000,
"lrmsd": 0.000,
"ilrmsd": 0.000,
"dockq": 1.000,
"cluster_id": "-",
"cluster_ranking": "-",
"model-cluster_ranking": "-",
"energy_term": 0.000,
},
{
"model": "",
"md5": "-",
"caprieval_rank": 2,
"score": float("nan"),
"irmsd": 8.327,
"fnat": 0.050,
"lrmsd": 20.937,
"ilrmsd": 18.248,
"dockq": 0.074,
"cluster_id": "-",
"cluster_ranking": "-",
"model-cluster_ranking": "-",
"energy_term": 0.000,
},
]
oberseved_data = []
for line in data:
values = line.strip().split("\t")
data_dict = {
"model": "", # don't check this, it's a path and will change
"md5": str(values[1]),
"caprieval_rank": int(values[2]),
"score": float(values[3]),
"irmsd": float(values[4]),
"fnat": float(values[5]),
"lrmsd": float(values[6]),
"ilrmsd": float(values[7]),
"dockq": float(values[8]),
"cluster_id": str(values[9]),
"cluster_ranking": str(values[10]),
"model-cluster_ranking": str(values[11]),
"energy_term": float(values[12]),
}
oberseved_data.append(data_dict)

for k, v in zip(expected_data, oberseved_data):
for key in k:
v1 = k[key]
v2 = v[key]
_check_capri_ss_tsv(
capri_file=str(Path(module.path, "capri_ss.tsv")),
expected_data=ss_data,
)

# Check the type
assert type(v1) == type(v2), f"Type mismatch for {key}"

if isinstance(v1, float):
if math.isnan(v1):
assert math.isnan(v2), f"Value mismatch for {key}"
else:
assert (
pytest.approx(v1, rel=1e-3) == v2
), f"Value mismatch for {key}"
else:
assert v1 == v2, f"Value mismatch for {key}"
_check_capri_clt_tsv(
capri_file=str(Path(module.path, "capri_clt.tsv")),
expected_data=clt_data,
)


def test_caprieval_default(caprieval_module, model_list):
def test_caprieval_default(
caprieval_module, model_list, expected_ss_data, expected_clt_data
):

caprieval_module.previous_io = MockPreviousIO(path=caprieval_module.path)
caprieval_module.run()

evaluate_caprieval_execution(caprieval_module, model_list)
evaluate_caprieval_execution(
caprieval_module, model_list, expected_ss_data, expected_clt_data
)


def test_caprieval_less_io(caprieval_module, model_list):
def test_caprieval_less_io(
caprieval_module, model_list, expected_ss_data, expected_clt_data
):
caprieval_module.previous_io = MockPreviousIO(path=caprieval_module.path)
caprieval_module.params["less_io"] = True

caprieval_module.run()

evaluate_caprieval_execution(caprieval_module, model_list)
evaluate_caprieval_execution(
caprieval_module, model_list, expected_ss_data, expected_clt_data
)

0 comments on commit a91d8d0

Please sign in to comment.