Skip to content

Commit

Permalink
fixing scripts tests, rearranging with contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
paverett committed Mar 25, 2024
1 parent 6d9d6ef commit d24ec43
Show file tree
Hide file tree
Showing 18 changed files with 91 additions and 125 deletions.
8 changes: 4 additions & 4 deletions astropy/convolution/tests/test_convolve_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,12 +787,12 @@ def test_big_fail(self):
Test that convolve_fft raises an exception if a too-large array is passed in.
"""

# note 512**3 * 16 bytes = 2.0 GB
# while a good idea, this approach did not work; it actually writes to disk
# arr = np.memmap('file.np', mode='w+', shape=(512, 512, 512), dtype=complex)
# this just allocates the memory but never touches it; it's better:
arr = np.empty([512, 512, 512], dtype=complex)
# note 512**3 * 16 bytes = 2.0 GB
with pytest.raises((ValueError, MemoryError)):
# while a good idea, this approach did not work; it actually writes to disk
# arr = np.memmap('file.np', mode='w+', shape=(512, 512, 512), dtype=complex)
# this just allocates the memory but never touches it; it's better:
convolve_fft(arr, arr)

def test_padding(self):
Expand Down
34 changes: 18 additions & 16 deletions astropy/convolution/tests/test_kernel_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,7 @@
AiryDisk2DKernel,
Ring2DKernel,
]
KERNEL_TYPES_WITH_ARGS = [
(AiryDisk2DKernel, (2,)),
(Box2DKernel, (2,)),
(Gaussian1DKernel, (2,)),
(Gaussian2DKernel, (2,)),
(RickerWavelet1DKernel, (2,)),
(RickerWavelet2DKernel, (2,)),
(Model1DKernel, (Gaussian1D(1, 0, 2),)),
(Model2DKernel, (Gaussian2D(1, 0, 0, 2, 2),)),
(Ring2DKernel, (9, 8)),
(Tophat2DKernel, (2,)),
(Trapezoid1DKernel, (2,)),
(Trapezoid1DKernel, (2,)),
]


NUMS = [1, 1.0, np.float32(1.0), np.float64(1.0)]

Expand Down Expand Up @@ -635,11 +622,26 @@ def test_kernel2d_initialization(self):
with pytest.raises(TypeError):
Kernel2D()

@pytest.mark.parametrize(["kernel", "opt"], KERNEL_TYPES_WITH_ARGS)
@pytest.mark.parametrize(
["kernel", "opt"],
[
(AiryDisk2DKernel, (2,)),
(Box2DKernel, (2,)),
(Gaussian1DKernel, (2,)),
(Gaussian2DKernel, (2,)),
(RickerWavelet1DKernel, (2,)),
(RickerWavelet2DKernel, (2,)),
(Model1DKernel, (Gaussian1D(1, 0, 2),)),
(Model2DKernel, (Gaussian2D(1, 0, 0, 2, 2),)),
(Ring2DKernel, (9, 8)),
(Tophat2DKernel, (2,)),
(Trapezoid1DKernel, (2,)),
],
)
def test_array_keyword_not_allowed(self, kernel, opt):
"""
Regression test for issue #10439
"""
x = np.ones([10, 10])
with pytest.raises(TypeError, match=r".* allowed .*"):
with pytest.raises(TypeError, match=r"Array argument not allowed for kernel.*"):
kernel(*opt, array=x)
1 change: 0 additions & 1 deletion astropy/coordinates/tests/test_angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,6 @@ def test_latitude():
):
Latitude(lon)

lon = Longitude(10, "deg")
lat = Latitude([20], "deg")
with pytest.raises(
TypeError, match="A Longitude angle cannot be assigned to a Latitude angle"
Expand Down
12 changes: 12 additions & 0 deletions astropy/cosmology/flrw/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@

from __future__ import annotations

import copy
from typing import TYPE_CHECKING, TypeVar

import pytest

from astropy.cosmology import core
from astropy.cosmology.tests.helper import clean_registry # noqa: F401
from astropy.tests.helper import pickle_protocol # noqa: F401

Expand Down Expand Up @@ -34,3 +38,11 @@ def filter_keys_from_items(
Iterable of ``(key, value)`` pairs with the ``filter_out`` keys removed.
"""
return ((k, v) for k, v in m.items() if k not in filter_out)


@pytest.fixture
def clean_cosmology_classes():
original = copy.deepcopy(core._COSMOLOGY_CLASSES)
yield
core._COSMOLOGY_CLASSES.clear()
core._COSMOLOGY_CLASSES.update(original)
6 changes: 1 addition & 5 deletions astropy/cosmology/flrw/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,7 @@ def setup_class(self):
# ---------------------------------------------------------------
# class-level

@pytest.mark.usefixtures("clean_cosmology_classes")
def test_init_subclass(self, cosmo_cls):
"""Test initializing subclass, mostly that can't have Ode0 in init."""
super().test_init_subclass(cosmo_cls)
Expand All @@ -1019,11 +1020,6 @@ class HASOde0SubClass(cosmo_cls):
def __init__(self, Ode0):
pass

try:
_COSMOLOGY_CLASSES.pop(HASOde0SubClass.__qualname__, None)
except UnboundLocalError:
pass

# ---------------------------------------------------------------
# instance-level

Expand Down
15 changes: 3 additions & 12 deletions astropy/io/ascii/tests/test_c_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,25 +866,16 @@ def test_rdb(read_rdb):
assert table["B"].dtype.kind in ("S", "U")
assert_equal(table["C"].dtype.kind, "f")

text = "A\tB\tC\nN\tS\tN\n4\tb\ta" # C column contains non-numeric data
with pytest.raises(ValueError) as e:
read_rdb(
text,
)
read_rdb("A\tB\tC\nN\tS\tN\n4\tb\ta") # C column contains non-numeric data
assert "Column C failed to convert" in str(e.value)

text = "A\tB\tC\nN\tN\n1\t2\t3" # not enough types specified
with pytest.raises(ValueError) as e:
read_rdb(
text,
)
read_rdb("A\tB\tC\nN\tN\n1\t2\t3") # not enough types specified
assert "mismatch between number of column names and column types" in str(e.value)

text = "A\tB\tC\nN\tN\t5\n1\t2\t3" # invalid type for column C
with pytest.raises(ValueError) as e:
read_rdb(
text,
)
read_rdb("A\tB\tC\nN\tN\t5\n1\t2\t3") # invalid type for column C
assert "type definitions do not all match [num](N|S)" in str(e.value)


Expand Down
3 changes: 1 addition & 2 deletions astropy/io/fits/scripts/fitscheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def handle_options(args):
args = ["-h"]

parser = argparse.ArgumentParser(
description=DESCRIPTION,
formatter_class=argparse.RawDescriptionHelpFormatter,
description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter
)

parser.add_argument(
Expand Down
3 changes: 1 addition & 2 deletions astropy/io/fits/scripts/fitsheader.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,7 @@ def print_headers_as_comparison(args):
def main(args=None):
"""This is the main function called by the `fitsheader` script."""
parser = argparse.ArgumentParser(
description=DESCRIPTION,
formatter_class=argparse.RawDescriptionHelpFormatter,
description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter
)

parser.add_argument(
Expand Down
3 changes: 1 addition & 2 deletions astropy/io/fits/scripts/fitsinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def fitsinfo(filename):
def main(args=None):
"""The main function called by the `fitsinfo` script."""
parser = argparse.ArgumentParser(
description=DESCRIPTION,
formatter_class=argparse.RawDescriptionHelpFormatter,
description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument(
"--version", action="version", version=f"%(prog)s {__version__}"
Expand Down
61 changes: 24 additions & 37 deletions astropy/io/fits/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,30 +643,24 @@ def test_open_file_handle(self):
pass

# Opening without explicitly specifying binary mode should fail
with (
pytest.raises(ValueError),
open(self.data("test0.fits")) as handle,
fits.open(handle) as _,
):
pass
with open(self.data("test0.fits")) as handle:
with pytest.raises(ValueError):
with fits.open(handle) as _:
pass

# All of these read modes should fail
for mode in ["r", "rt"]:
with (
pytest.raises(ValueError),
open(self.data("test0.fits"), mode=mode) as handle,
fits.open(handle) as _,
):
pass
with open(self.data("test0.fits"), mode=mode) as handle:
with pytest.raises(ValueError):
with fits.open(handle) as _:
pass

# These update or write modes should fail as well
for mode in ["w", "wt", "w+", "wt+", "r+", "rt+", "a", "at", "a+", "at+"]:
with (
pytest.raises(ValueError),
open(self.temp("temp.fits"), mode=mode) as handle,
fits.open(handle) as _,
):
pass
with open(self.temp("temp.fits"), mode=mode) as handle:
with pytest.raises(ValueError):
with fits.open(handle) as _:
pass

def test_fits_file_handle_mode_combo(self):
# This should work fine since no mode is given
Expand All @@ -680,12 +674,10 @@ def test_fits_file_handle_mode_combo(self):
pass

# This should not work since the modes conflict
with (
pytest.raises(ValueError),
open(self.data("test0.fits"), "rb") as handle,
fits.open(handle, mode="ostream") as _,
):
pass
with open(self.data("test0.fits"), "rb") as handle:
with pytest.raises(ValueError):
with fits.open(handle, mode="ostream") as _:
pass

def test_open_from_url(self):
file_url = "file:///" + self.data("test0.fits").lstrip("/")
Expand All @@ -695,20 +687,13 @@ def test_open_from_url(self):

# It will not be possible to write to a file that is from a URL object
for mode in ("ostream", "append", "update"):
with (
pytest.raises(ValueError),
urllib.request.urlopen(file_url) as urlobj,
fits.open(urlobj, mode=mode) as _,
):
pass
with urllib.request.urlopen(file_url) as urlobj:
with pytest.raises(ValueError):
with fits.open(urlobj, mode=mode) as _:
pass

@pytest.mark.remote_data(source="astropy")
def test_open_from_remote_url(self):
def open_from_remote_url(remote_url, mode):
with urllib.request.urlopen(remote_url) as urlobj:
with fits.open(urlobj, mode=mode) as fits_handle:
assert len(fits_handle) == 1

for dataurl in (conf.dataurl, conf.dataurl_mirror):
remote_url = f"{dataurl}/allsky/allsky_rosat.fits"
try:
Expand All @@ -717,8 +702,10 @@ def open_from_remote_url(remote_url, mode):
assert len(fits_handle) == 1

for mode in ("ostream", "append", "update"):
with pytest.raises(ValueError):
open_from_remote_url(remote_url, mode)
with urllib.request.urlopen(remote_url) as urlobj:
with pytest.raises(ValueError):
with fits.open(urlobj, mode=mode) as fits_handle:
assert len(fits_handle) == 1

except (urllib.error.HTTPError, urllib.error.URLError):
continue
Expand Down
11 changes: 4 additions & 7 deletions astropy/io/fits/tests/test_fitscheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,12 @@ def test_help(self):

def test_version(self, capsys):
script = "fitscheck"
p = subprocess.Popen(
[script, "--version"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
script_result = subprocess.run(
[script, "--version"], capture_output=True, text=True, check=False
)
stdout, stderr = p.communicate()

assert p.returncode == 0
assert stdout.decode("utf-8").strip() == f"fitscheck {version}"
assert script_result.returncode == 0
assert script_result.stdout.strip() == f"{script} {version}"

def test_missing_file(self, capsys):
assert fitscheck.main(["missing.fits"]) == 1
Expand Down
11 changes: 4 additions & 7 deletions astropy/io/fits/tests/test_fitsdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,12 @@ def test_help(self):

def test_version(self, capsys):
script = "fitsdiff"
p = subprocess.Popen(
[script, "--version"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
script_result = subprocess.run(
[script, "--version"], capture_output=True, text=True, check=False
)
stdout, stderr = p.communicate()

assert p.returncode == 0
assert stdout.decode("utf-8").strip() == f"fitsdiff {version}"
assert script_result.returncode == 0
assert script_result.stdout.strip() == f"{script} {version}"

def test_noargs(self):
with pytest.raises(SystemExit) as e:
Expand Down
11 changes: 4 additions & 7 deletions astropy/io/fits/tests/test_fitsheader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@ def test_help(self):

def test_version(self, capsys):
script = "fitsheader"
p = subprocess.Popen(
[script, "--version"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
script_result = subprocess.run(
[script, "--version"], capture_output=True, text=True, check=False
)
stdout, stderr = p.communicate()

assert p.returncode == 0
assert stdout.decode("utf-8").strip() == f"fitsheader {version}"
assert script_result.returncode == 0
assert script_result.stdout.strip() == f"{script} {version}"

def test_file_exists(self, capsys):
fitsheader.main([self.data("arange.fits")])
Expand Down
11 changes: 4 additions & 7 deletions astropy/io/fits/tests/test_fitsinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@ def test_help(self):

def test_version(self, capsys):
script = "fitsinfo"
p = subprocess.Popen(
[script, "--version"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
script_result = subprocess.run(
[script, "--version"], capture_output=True, text=True, check=False
)
stdout, stderr = p.communicate()

assert p.returncode == 0
assert stdout.decode("utf-8").strip() == f"fitsinfo {version}"
assert script_result.returncode == 0
assert script_result.stdout.strip() == f"{script} {version}"

def test_onefile(self, capsys):
fitsinfo.main([self.data("arange.fits")])
Expand Down
5 changes: 3 additions & 2 deletions astropy/io/fits/tests/test_fitstime.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ def test_time_to_fits_loc(self, table_types):
self.time, format="isot", scale="tt", location=EarthLocation(2, 3, 4)
)

with pytest.raises(ValueError) as err:
with pytest.raises(
ValueError, match="Multiple Time Columns with different geocentric"
):
table, hdr = time_to_fits(t)
assert "Multiple Time Columns with different geocentric" in str(err.value)

# Check that Time column with no location specified will assume global location
t["b"] = Time(self.time, format="isot", scale="tt", location=None)
Expand Down
3 changes: 1 addition & 2 deletions astropy/uncertainty/tests/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,7 @@ def test_reprs():
],
)
def test_wrong_kw_fails(func, kws):
kw_temp = kws.copy()
kw_temp["n_sample"] = 100 # note the missing "s"
kw_temp = kws | {"n_sample": 100} # note the missing "s"
with pytest.raises(TypeError, match="missing 1 required"):
assert func(**kw_temp).n_samples == 100
kw_temp = kws.copy()
Expand Down
Loading

0 comments on commit d24ec43

Please sign in to comment.