Skip to content

Commit

Permalink
fix(fixtures): fix model-loading fixtures and utilities (MODFLOW-USGS#12
Browse files Browse the repository at this point in the history
)

* rename modflow_devtools.misc.get_models to get_model_paths
* fix get_model_paths filtering by package
* sort paths returned by get_model_paths
* refactor get_packages function
* expand tests
  • Loading branch information
wpbonelli authored Nov 11, 2022
1 parent 3c63aaa commit 1e5fabd
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 61 deletions.
10 changes: 5 additions & 5 deletions modflow_devtools/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, List, Optional

import pytest
from modflow_devtools.misc import get_mf6_ftypes, get_models
from modflow_devtools.misc import get_model_paths, get_packages

# temporary directory fixtures

Expand Down Expand Up @@ -174,7 +174,7 @@ def pytest_generate_tests(metafunc):
key = "test_model_mf6"
if key in metafunc.fixturenames:
models = (
get_models(
get_model_paths(
Path(repos_path) / "modflow6-testmodels" / "mf6",
prefix="test",
excluded=["test205_gwtbuy-henrytidal"],
Expand All @@ -189,7 +189,7 @@ def pytest_generate_tests(metafunc):
key = "test_model_mf5to6"
if key in metafunc.fixturenames:
models = (
get_models(
get_model_paths(
Path(repos_path) / "modflow6-testmodels" / "mf5to6",
prefix="test",
namefile="*.nam",
Expand All @@ -205,7 +205,7 @@ def pytest_generate_tests(metafunc):
key = "large_test_model"
if key in metafunc.fixturenames:
models = (
get_models(
get_model_paths(
Path(repos_path) / "modflow6-largetestmodels",
prefix="test",
namefile="*.nam",
Expand Down Expand Up @@ -292,7 +292,7 @@ def get_examples():
for name, namefiles in examples.items():
ftypes = []
for namefile in namefiles:
ftype = get_mf6_ftypes(namefile, packages_selected)
ftype = get_packages(namefile, packages_selected)
if ftype not in ftypes:
ftypes += ftype
if len(ftypes) > 0:
Expand Down
86 changes: 48 additions & 38 deletions modflow_devtools/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,17 @@ def get_current_branch() -> str:
raise ValueError(f"Could not determine current branch: {stderr}")


def get_mf6_ftypes(namefile_path: PathLike, ftypekeys: List[str]) -> List[str]:
def get_packages(namefile_path: PathLike) -> List[str]:
"""
Return a list of FTYPES that are in the name file and in ftypekeys.
Return a list of packages used by the model defined in the given namefile.
Parameters
----------
namefile_path : str
path to a MODFLOW 6 name file
ftypekeys : list
list of desired FTYPEs
namefile_path : PathLike
path to MODFLOW 6 name file
Returns
-------
ftypes : list
list of FTYPES that match ftypekeys in namefile
list of package types
"""
with open(namefile_path, "r") as f:
lines = f.readlines()
Expand All @@ -126,22 +123,29 @@ def get_mf6_ftypes(namefile_path: PathLike, ftypekeys: List[str]) -> List[str]:
if len(ll) < 2:
continue

if ll[0] in ["#", "!"]:
l = ll[0].lower()
if any(l.startswith(c) for c in ["#", "!", "data", "list"]) or l in [
"begin",
"end",
"memory_print_option",
]:
continue

for key in ftypekeys:
if key.lower() in ll[0].lower():
ftypes.append(ll[0])
# strip "6" from package name
l = l.replace("6", "")

return ftypes
ftypes.append(l.lower())

return list(set(ftypes))

def has_packages(namefile_path: PathLike, packages: List[str]) -> bool:
ftypes = [item.upper() for item in get_mf6_ftypes(namefile_path, packages)]
return len(ftypes) > 0

def has_package(namefile_path: PathLike, package: str) -> bool:
"""Determines whether the model with the given namefile contains the selected package"""
packages = get_packages(namefile_path)
return package.lower in packages

def get_models(

def get_model_paths(
path: PathLike,
prefix: str = None,
namefile: str = "mfsim.nam",
Expand All @@ -150,7 +154,12 @@ def get_models(
packages=None,
) -> List[Path]:
"""
Find models in the given filesystem location.
Find models recursively in the given location.
Models can be filtered or excluded by pattern,
filtered by packages used or naming convention
for namefiles, or by parent folder name prefix.
The path to the model folder (i.e., the folder
containing the model's namefile) is returned.
"""

# if path doesn't exist, return empty list
Expand All @@ -161,7 +170,7 @@ def get_models(
namfile_paths = [
p
for p in Path(path).rglob(
f"{prefix}*/{namefile}" if prefix else namefile
f"{prefix}*/**/{namefile}" if prefix else namefile
)
]

Expand All @@ -172,37 +181,36 @@ def get_models(
if (not excluded or not any(e in str(p) for e in excluded))
]

# filter by package (optional)
# filter by package
if packages:
namfile_paths = [
p
for p in namfile_paths
if (has_packages(p, packages) if packages else True)
]

# get model dir paths
filtered = []
for nfp in namfile_paths:
nf_pkgs = get_packages(nfp)
shared = set(nf_pkgs).intersection(
set([p.lower() for p in packages])
)
if any(shared):
filtered.append(nfp)
namfile_paths = filtered

# get model folder paths
model_paths = [p.parent for p in namfile_paths]

# filter by model name (optional)
# filter by model name
if selected:
model_paths = [
model
for model in model_paths
if any(s in model.name for s in selected)
]

# exclude dev examples on master or release branches
branch = get_current_branch()
if "master" in branch.lower() or "release" in branch.lower():
model_paths = [
model for model in model_paths if "_dev" not in model.name.lower()
]

return model_paths
return sorted(model_paths)


def is_connected(hostname):
"""See https://stackoverflow.com/a/20913928/ to test hostname."""
"""
Tests whether the given URL is accessible.
See https://stackoverflow.com/a/20913928/."""
try:
host = socket.gethostbyname(hostname)
s = socket.create_connection((host, 80), 2)
Expand All @@ -214,6 +222,8 @@ def is_connected(hostname):


def is_in_ci():
"""Determines whether the current process is running GitHub Actions CI"""

# if running in GitHub Actions CI, "CI" variable always set to true
# https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables
return bool(environ.get("CI", None))
Expand All @@ -222,7 +232,7 @@ def is_in_ci():
def is_github_rate_limited() -> Optional[bool]:
"""
Determines if a GitHub API rate limit is applied to the current IP.
Note that running this function will consume an API request!
Running this function will consume an API request!
Returns
-------
Expand Down
23 changes: 9 additions & 14 deletions modflow_devtools/test/test_executables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,12 @@
from modflow_devtools.executables import Executables
from modflow_devtools.misc import add_sys_path, get_suffixes

_bin_path = Path(environ.get("BIN_PATH")).expanduser()
_bin_path = Path(environ.get("BIN_PATH")).expanduser().absolute()
_ext, _ = get_suffixes(sys.platform)


@pytest.fixture
def bin_path(module_tmpdir) -> Path:
return _bin_path.absolute()


@pytest.mark.skipif(not _bin_path.is_dir(), reason="bin directory not found")
def test_get_path(bin_path):
def test_get_path():
with add_sys_path(str(_bin_path)):
ext, _ = get_suffixes(sys.platform)
assert (
Expand All @@ -26,22 +21,22 @@ def test_get_path(bin_path):
)


def test_get_version(bin_path):
with add_sys_path(str(bin_path)):
ver_str = Executables.get_version("mf6", path=bin_path).partition(" ")
def test_get_version():
with add_sys_path(str(_bin_path)):
ver_str = Executables.get_version("mf6", path=_bin_path).partition(" ")
print(ver_str)
version = int(ver_str[0].split(".")[0])
assert version >= 6


@pytest.fixture
def exes(bin_path):
return Executables(mf6=bin_path / f"mf6{_ext}")
def exes():
return Executables(mf6=_bin_path / f"mf6{_ext}")


def test_executables_mapping(bin_path, exes):
def test_executables_mapping(exes):
print(exes.mf6)
assert exes.mf6 == bin_path / f"mf6{_ext}"
assert exes.mf6 == _bin_path / f"mf6{_ext}"


def test_executables_usage(exes):
Expand Down
6 changes: 4 additions & 2 deletions modflow_devtools/test/test_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,11 @@ def test_test_model_mf6(test_model_mf6):

def test_test_model_mf5to6(test_model_mf5to6):
assert isinstance(test_model_mf5to6, Path)
assert len(list(test_model_mf5to6.glob("*.nam"))) >= 1
assert any(list(test_model_mf5to6.glob("*.nam")))


def test_large_test_model(large_test_model):
assert isinstance(large_test_model, Path)
assert (large_test_model / "mfsim.nam").is_file()
assert (large_test_model / "mfsim.nam").is_file() or any(
list(large_test_model.glob("*.nam"))
)
61 changes: 59 additions & 2 deletions modflow_devtools/test/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,59 @@
def test_set_dir():
pass
import os
from os import environ
from pathlib import Path

import pytest
from modflow_devtools.misc import get_model_paths, get_packages, set_dir


def test_set_dir(tmp_path):
assert Path(os.getcwd()) != tmp_path
with set_dir(tmp_path):
assert Path(os.getcwd()) == tmp_path
assert Path(os.getcwd()) != tmp_path


_repos_path = Path(environ.get("REPOS_PATH")).expanduser().absolute()
_examples_repo_path = _repos_path / "modflow6-examples"
_examples_path = _examples_repo_path / "examples"
_example_paths = (
sorted(list(_examples_path.glob("ex-*")))
if _examples_path.is_dir()
else []
)


@pytest.mark.skipif(not any(_example_paths), reason="examples not found")
def test_has_packages():
example_path = _example_paths[0]
packages = get_packages(example_path / "mfsim.nam")
assert set(packages) == {"tdis", "gwf", "ims"}


@pytest.mark.skipif(not any(_example_paths), reason="examples not found")
def test_get_model_paths():
paths = get_model_paths(_examples_path)
assert len(paths) == 127

paths = get_model_paths(_examples_path, namefile="*.nam")
assert len(paths) == 339


def test_get_model_paths_exclude_patterns():
paths = get_model_paths(_examples_path, excluded=["gwt"])
assert len(paths) == 63


def test_get_model_paths_select_prefix():
paths = get_model_paths(_examples_path, prefix="ex2")
assert not any(paths)


def test_get_model_paths_select_patterns():
paths = get_model_paths(_examples_path, selected=["gwf"])
assert len(paths) == 70


def test_get_model_paths_select_packages():
paths = get_model_paths(_examples_path, namefile="*.nam", packages=["wel"])
assert len(paths) == 64

0 comments on commit 1e5fabd

Please sign in to comment.