Skip to content

Commit

Permalink
feat: Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lewisjared committed Jul 11, 2024
1 parent 7f3a458 commit 386acd7
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ INTERMEDIATES=data/intermediates
# Domain info filename. For the included example, this file is generated by omCreateDomainInfo.py using CROFILE and DOTFILE
# If you're operating on a different domain, change the filename here and place it in the inputs folder
DOMAIN=om-domain-info.nc
DOMAIN_NAME=aust-test
DOMAIN_NAME=aust10km
DOMAIN_VERSION=v1.0.0

# Remote storage for input files
Expand Down
18 changes: 7 additions & 11 deletions scripts/omDownloadInputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,18 @@
from collections.abc import Iterable

import attrs
from openmethane_prior.config import PriorConfig, load_config_from_env
from openmethane_prior.config import load_config_from_env
from openmethane_prior.inputs import download_input_file


def download_input_files(
config: PriorConfig, download_path: pathlib.Path, fragments: Iterable[str]
remote: str, download_path: pathlib.Path, fragments: Iterable[str]
) -> list[pathlib.Path]:
"""
Download input files from a remote location
Parameters
----------
config
OpenMethane-Prior configuration
download_path
Path to download the files to
fragments
Expand All @@ -54,16 +52,14 @@ def download_input_files(
List of input files that have been fetched or found locally.
"""
download_path.mkdir(parents=True, exist_ok=True)

downloaded_files = []
for name, url_fragment in fragments:
save_path = config.as_input_file(url_fragment).absolute()
for url_fragment in fragments:
save_path = download_path / url_fragment

if save_path.is_relative_to(config.input_path):
if not save_path.resolve().is_relative_to(download_path.resolve()):
raise ValueError(f"Check download fragment: {url_fragment}")

download_input_file(config.remote, url_fragment, save_path)
download_input_file(remote, url_fragment, save_path)
downloaded_files.append(save_path)
return downloaded_files

Expand All @@ -78,7 +74,7 @@ def download_input_files(
layer_fragments.append(config.input_domain.url_fragment())

download_input_files(
config=config,
remote=config.remote,
download_path=config.input_path,
fragments=layer_fragments,
)
34 changes: 25 additions & 9 deletions src/openmethane_prior/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class LayerInputs:


@attrs.frozen()
class InputDomain:
class PublishedInputDomain:
"""
Input domain configuration
Expand Down Expand Up @@ -59,13 +59,19 @@ def url_fragment(self) -> str:
class PriorConfig:
"""Configuration used to describe the prior data sources and the output directories."""

domain: str
remote: str
input_path: pathlib.Path
output_path: pathlib.Path
intermediates_path: pathlib.Path

input_domain: InputDomain | None
input_domain: PublishedInputDomain | str
"""Input domain specification
If provided, use a published domain as the input domain.
Otherwise, a file named `output_domain` is used as the input domain.
"""
output_domain: str
"""Name of the output domain file"""
layer_inputs: LayerInputs

def as_input_file(self, name: str | pathlib.Path) -> pathlib.Path:
Expand Down Expand Up @@ -118,13 +124,22 @@ def domain_cell_area(self):

@property
def input_domain_file(self):
"""Get the filename of the input domain"""
return self.as_input_file(self.domain)
"""
Get the filename of the input domain
Uses a published domain if it is provided otherwise uses a user-specified file name
"""
if isinstance(self.input_domain, PublishedInputDomain):
return self.as_input_file(self.input_domain.url_fragment())
elif isinstance(self.input_domain, str):
return self.as_input_file(self.input_domain)
else:
raise TypeError("Could not interpret the 'input_domain' field")

@property
def output_domain_file(self):
"""Get the filename of the output domain"""
return self.as_output_file(f"out-{self.domain}")
return self.as_output_file(self.output_domain)


def load_config_from_env(**overrides: typing.Any) -> PriorConfig:
Expand All @@ -143,21 +158,22 @@ def load_config_from_env(**overrides: typing.Any) -> PriorConfig:
env.read_env(verbose=True)

if env.str("DOMAIN_NAME", None) and env.str("DOMAIN_VERSION", None):
input_domain = InputDomain(
input_domain = PublishedInputDomain(
name=env.str("DOMAIN_NAME"),
version=env.str("DOMAIN_VERSION"),
)
else:
# Default to using a user-specified file as the input domain
# TODO: Log?
input_domain = None
input_domain = env.str("DOMAIN")

options = dict(
domain=env("DOMAIN"),
remote=env("PRIOR_REMOTE"),
input_path=env.path("INPUTS", "data/inputs"),
output_path=env.path("OUTPUTS", "data/outputs"),
intermediates_path=env.path("INTERMEDIATES", "data/processed"),
input_domain=input_domain,
output_domain=env.str("OUTPUT_DOMAIN", "out-om-domain-info.nc"),
layer_inputs=LayerInputs(
electricity_path=env.path("CH4_ELECTRICITY"),
oil_gas_path=env.path("CH4_OILGAS"),
Expand Down
6 changes: 4 additions & 2 deletions src/openmethane_prior/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def download_input_file(remote_url: str, url_fragment: str, save_path: pathlib.P
save_path.parent.mkdir(parents=True, exist_ok=True)

with requests.get(url, stream=True, timeout=30) as response:
response.raise_for_status()

with open(save_path, mode="wb") as file:
for chunk in response.iter_content(chunk_size=10 * 1024):
file.write(chunk)
Expand Down Expand Up @@ -105,8 +107,8 @@ def check_input_files(config: PriorConfig):
if len(errors) > 0:
print(
"Some required files are missing. "
"Suggest running omDownloadInputs.py if you're using the default input file set, "
"and omCreateDomainInfo.py if you haven't already. See issues below."
"Suggest running omDownloadInputs.py if you're using the default input file set. "
"See issues below."
)
print("\n".join(errors))
sys.exit(1)
Expand Down
138 changes: 89 additions & 49 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import dotenv
import pytest
import xarray as xr

from openmethane_prior.config import PriorConfig, load_config_from_env
from openmethane_prior.config import PriorConfig, PublishedInputDomain, load_config_from_env
from scripts.omDownloadInputs import download_input_files
from scripts.omPrior import run_prior

Expand All @@ -35,19 +34,44 @@ def env(monkeypatch, root_dir):


@pytest.fixture()
def cro_xr(config) -> xr.Dataset:
return xr.open_dataset(config.cro_file)
def config(tmp_path_factory) -> PriorConfig:
"""Default configuration"""
input_dir = tmp_path_factory.mktemp("inputs")
intermediate_dir = tmp_path_factory.mktemp("intermediates")
output_dir = tmp_path_factory.mktemp("outputs")

return load_config_from_env(
input_path=input_dir, intermediates_path=intermediate_dir, output_path=output_dir
)

@pytest.fixture()
def dot_xr(config) -> xr.Dataset:
return xr.open_dataset(config.dot_file)

@pytest.fixture(scope="session")
def fetch_published_domain(root_dir) -> list[pathlib.Path]:
"""
Fetch and cache the domain files.
@pytest.fixture()
def config() -> PriorConfig:
"""Default configuration"""
return load_config_from_env()
Don't use this fixture directly,
instead use `input_domain` to copy the files to the input directory.
Returns
-------
List of cached input files
"""
config = load_config_from_env()
published_domains = [
PublishedInputDomain(name="aust-test", version="v1.0.0"),
PublishedInputDomain(name="aust10km", version="v1.0.0"),
]

fragments = [domain.url_fragment() for domain in published_domains]

downloaded_files = download_input_files(
remote=config.remote,
download_path=root_dir / ".cache",
fragments=fragments,
)

return downloaded_files


# Fixture to download and later remove all input files
Expand All @@ -68,80 +92,81 @@ def fetch_input_files(root_dir) -> list[pathlib.Path]:
fragments = [str(f) for f in attrs.asdict(config.layer_inputs).values()]

downloaded_files = download_input_files(
download_path=root_dir / ".cache", fragments=fragments, remote=config.remote
remote=config.remote,
download_path=root_dir / ".cache",
fragments=fragments,
)

return downloaded_files


def copy_input_files(input_path: str | pathlib.Path, cached_files: list[pathlib.Path]):
def copy_input_files(
cache_path: pathlib.Path, input_path: str | pathlib.Path, cached_fragments: list[pathlib.Path]
):
"""
Copy input files from the cache into the input directory
Parameters
----------
cache_path
Path to the cache directory
input_path
Path to the input directory
cached_files
List of files that have been cached and should be copied into the input_path
"""
files = [input_path / cached_file.name for cached_file in cached_files]
cached_fragments
List of files that have been cached and should be copied into the input_path.
input_path.mkdir(parents=True, exist_ok=True)
Paths should be relative to the cache directory
"""
files = []
for fragment in cached_fragments:
input_file = pathlib.Path(input_path) / fragment
cached_file = cache_path / fragment
assert not input_file.exists()

for cached_file, input_file in zip(cached_files, files):
try:
os.remove(input_file)
except FileNotFoundError:
pass
input_file.parent.mkdir(exist_ok=True, parents=True)

shutil.copyfile(cached_file, input_file)

files.append(input_file)

yield files

for filepath in files:
if filepath.exists():
os.remove(filepath)
os.remove(filepath)


@pytest.fixture()
def input_files(root_dir, fetch_input_files, config) -> list[pathlib.Path]:
def input_files(root_dir, fetch_input_files, fetch_published_domain, config) -> list[pathlib.Path]:
"""
Ensure that the required input files are in the input directory.
The input files are copied from a local cache `.cache`
"""
yield from copy_input_files(config.input_path, fetch_input_files)
cache_dir = root_dir / ".cache"
files_to_copy = fetch_input_files + fetch_published_domain

fragments = [file.relative_to(cache_dir) for file in files_to_copy]
yield from copy_input_files(cache_dir, config.input_path, fragments)

@pytest.fixture(scope="session")
def input_domain(root_dir) -> xr.Dataset:

@pytest.fixture()
def input_domain(config, root_dir, input_files) -> xr.Dataset:
"""
Generate the input domain
Get an input domain
Returns
-------
The input domain as an xarray dataset
"""
config = load_config_from_env()

domain = create_domain_info(
geometry_file=config.geometry_file,
cross_file=config.cro_file,
dot_file=config.dot_file,
)
write_domain_info(domain, config.input_domain_file)

assert config.input_domain_file.exists()

yield domain

if config.input_domain_file.exists():
os.remove(config.input_domain_file)
yield config.domain_dataset()


@pytest.fixture(scope="session")
def output_domain(root_dir, input_domain, fetch_input_files, tmp_path_factory) -> xr.Dataset:
def output_domain(
root_dir, fetch_input_files, fetch_published_domain, tmp_path_factory
) -> xr.Dataset:
"""
Run the output domain
Expand All @@ -151,11 +176,27 @@ def output_domain(root_dir, input_domain, fetch_input_files, tmp_path_factory) -
"""
# Manually copy the input files to the input directory
# Can't use the config/input_files fixtures because we want to only run this step once
output_dir = tmp_path_factory.mktemp("data")
config = load_config_from_env(output_path=output_dir)
input_dir = tmp_path_factory.mktemp("inputs")
intermediate_dir = tmp_path_factory.mktemp("intermediates")
output_dir = tmp_path_factory.mktemp("outputs")

config = load_config_from_env(
input_path=input_dir,
intermediates_path=intermediate_dir,
output_path=output_dir,
# Use the test domain to speed things up
# input_domain=PublishedInputDomain(
# name="aust-test",
# version="v1.0.0",
# ),
)

# Use the factory method as input_files has "function" scope
input_files = next(copy_input_files(config.input_path, fetch_input_files))
cache_dir = root_dir / ".cache"
input_fragments = [
file.relative_to(cache_dir) for file in fetch_input_files + fetch_published_domain
]
input_files = next(copy_input_files(root_dir / ".cache", config.input_path, input_fragments))

run_prior(
config,
Expand All @@ -170,5 +211,4 @@ def output_domain(root_dir, input_domain, fetch_input_files, tmp_path_factory) -

# Manually clean up any leftover files
for filepath in input_files:
if filepath.exists():
os.remove(filepath)
os.remove(filepath)
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,10 @@
import json
from io import StringIO

import pytest

from scripts.omDomainJSON import write_domain_json


def test_001_json_structure(config, input_domain):
input_domain.to_netcdf(config.input_domain_file)

if not config.input_domain_file.exists():
pytest.mark.skip("Missing domain file")

outfile = StringIO()

# generate the JSON, writing to a memory buffer
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 386acd7

Please sign in to comment.