Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor get_data_names and check functions #295

Merged
merged 7 commits into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 51 additions & 50 deletions verde/base/base_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,14 @@ class BaseGridder(BaseEstimator):
# using this name as a basis.
extra_coords_name = "extra_coord"

# Define default values for data_names depending on the number of data
# arrays returned by predict method.
data_names_defaults = [
("scalars",),
("east_component", "north_component"),
("east_component", "north_component", "vertical_component"),
]

def predict(self, coordinates):
"""
Predict data on the given coordinate values. NOT IMPLEMENTED.
Expand Down Expand Up @@ -416,8 +424,7 @@ def grid(
self.predict(project_coordinates(coordinates, projection))
)
# Get names for data and any extra coordinates
data_names = check_data_names(data_names)
data_names = get_data_names(data, data_names)
data_names = self._get_data_names(data, data_names)
extra_coords_names = self._get_extra_coords_names(coordinates)
# Create xarray.Dataset
dataset = make_xarray_grid(
Expand Down Expand Up @@ -506,8 +513,7 @@ def scatter(
data = check_data(
self.predict(project_coordinates(coordinates, projection))
)
data_names = check_data_names(data_names)
data_names = get_data_names(data, data_names)
data_names = self._get_data_names(data, data_names)
columns = [(dims[0], coordinates[1]), (dims[1], coordinates[0])]
extra_coords_names = self._get_extra_coords_names(coordinates)
columns.extend(zip(extra_coords_names, coordinates[2:]))
Expand Down Expand Up @@ -618,8 +624,7 @@ def profile(
# profile but Cartesian distances.
if projection is not None:
coordinates = project_coordinates(coordinates, projection, inverse=True)
data_names = check_data_names(data_names)
data_names = get_data_names(data, data_names)
data_names = self._get_data_names(data, data_names)
columns = [
(dims[0], coordinates[1]),
(dims[1], coordinates[0]),
Expand Down Expand Up @@ -665,6 +670,46 @@ def _get_extra_coords_names(self, coordinates):
names.append(name)
return names

def _get_data_names(self, data, data_names):
    """
    Get default names for data fields if none are given based on the data.

    Parameters
    ----------
    data : sequence
        The data components. Only its length is used here, to pick the
        matching entry of ``self.data_names_defaults``.
    data_names : None, str or sequence of str
        Custom names for the data components. If None, the class defaults
        are used. Otherwise the names are validated against ``data`` by
        ``check_data_names`` and returned.

    Returns
    -------
    data_names : tuple or sequence of str
        One name per data component.

    Raises
    ------
    ValueError
        If ``data_names`` is None and ``data`` is empty or has more
        components than there are entries in ``self.data_names_defaults``.

    Examples
    --------

    >>> import numpy as np
    >>> east, north, up = [np.arange(10)]*3
    >>> gridder = BaseGridder()
    >>> gridder._get_data_names((east,), data_names=None)
    ('scalars',)
    >>> gridder._get_data_names((east, north), data_names=None)
    ('east_component', 'north_component')
    >>> gridder._get_data_names((east, north, up), data_names=None)
    ('east_component', 'north_component', 'vertical_component')
    >>> gridder._get_data_names((east,), data_names="john")
    ('john',)
    >>> gridder._get_data_names((east,), data_names=("paul",))
    ('paul',)
    >>> gridder._get_data_names(
    ...     (up, north), data_names=('ringo', 'george')
    ... )
    ('ringo', 'george')
    >>> gridder._get_data_names((north,), data_names=["brian"])
    ['brian']

    """
    # Custom names take precedence: validate them and hand them back
    if data_names is not None:
        return check_data_names(data, data_names)
    # Otherwise fall back to the default names defined for the class
    defaults = self.data_names_defaults
    n_components = len(data)
    # Guard against empty data: index -1 would silently pick the last
    # (largest) set of default names instead of failing.
    if n_components == 0:
        raise ValueError("Cannot assign default data names to empty data.")
    if n_components > len(defaults):
        # Report the real limit so that subclasses overriding
        # data_names_defaults get an accurate message.
        raise ValueError(
            "Default data names only available for up to {} components. ".format(
                len(defaults)
            )
            + "Must provide custom names through the 'data_names' argument."
        )
    return defaults[n_components - 1]


def project_coordinates(coordinates, projection, **kwargs):
"""
Expand Down Expand Up @@ -702,50 +747,6 @@ def project_coordinates(coordinates, projection, **kwargs):
return proj_coordinates


def get_data_names(data, data_names):
    """
    Get default names for data fields if none are given based on the data.

    Parameters
    ----------
    data : sequence
        The data components. Only its length is used.
    data_names : None or sequence of str
        Custom names for the data components. If given, it must have as
        many elements as ``data`` and is returned unchanged. If None,
        default names are chosen from the number of components.

    Returns
    -------
    data_names : sequence of str
        One name per data component.

    Raises
    ------
    ValueError
        If the number of names doesn't match the number of components, or
        if ``data_names`` is None and ``data`` is empty or has more than 3
        components.

    Examples
    --------

    >>> import numpy as np
    >>> east, north, up = [np.arange(10)]*3
    >>> get_data_names((east,), data_names=None)
    ('scalars',)
    >>> get_data_names((east, north), data_names=None)
    ('east_component', 'north_component')
    >>> get_data_names((east, north, up), data_names=None)
    ('east_component', 'north_component', 'vertical_component')
    >>> get_data_names((up, north), data_names=('ringo', 'george'))
    ('ringo', 'george')

    """
    n_components = len(data)
    if data_names is not None:
        # The message is neutral on purpose: the mismatch can be either too
        # few or too many names.
        if n_components != len(data_names):
            raise ValueError(
                "Data has {} components but {} names provided: {}".format(
                    n_components, len(data_names), str(data_names)
                )
            )
        return data_names
    data_types = [
        ("scalars",),
        ("east_component", "north_component"),
        ("east_component", "north_component", "vertical_component"),
    ]
    # Guard against empty data: index -1 would silently return the names
    # for 3 components.
    if n_components == 0:
        raise ValueError("Cannot assign default data names to empty data.")
    if n_components > len(data_types):
        raise ValueError(
            "Default data names only available for up to 3 components. "
            "Must provide custom names through the 'data_names' argument."
        )
    return data_types[n_components - 1]


def get_instance_region(instance, region):
"""
Get the region attribute stored in instance if one is not provided.
Expand Down
72 changes: 62 additions & 10 deletions verde/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,27 +104,39 @@ def check_data(data):
return data


def check_data_names(data, data_names):
    """
    Check *data_names* against *data*.

    Also, convert ``data_names`` to a tuple if it's a single string.

    Parameters
    ----------
    data : sequence
        The data components. Only its length is used.
    data_names : str or sequence of str
        Names for the data components. A single string is wrapped in a
        1-element tuple; any other sequence is returned unchanged.

    Returns
    -------
    data_names : tuple or sequence of str
        One name per data component.

    Raises
    ------
    ValueError
        If ``data_names`` is None or if its number of elements differs
        from the number of elements in ``data``.

    Examples
    --------

    >>> import numpy as np
    >>> east, north, scalar = [np.array(10)]*3
    >>> check_data_names((scalar,), "dummy")
    ('dummy',)
    >>> check_data_names((scalar,), ("dummy",))
    ('dummy',)
    >>> check_data_names((scalar,), ["dummy"])
    ['dummy']
    >>> check_data_names((east, north), ("component_x", "component_y"))
    ('component_x', 'component_y')
    """
    # None is never a valid set of names
    if data_names is None:
        raise ValueError("Invalid data_names equal to None.")
    # Convert single string to tuple
    if isinstance(data_names, str):
        data_names = (data_names,)
    # Raise error if data and data_names don't have the same number of
    # elements. The message is neutral because the mismatch can be either
    # too few or too many names.
    if len(data) != len(data_names):
        raise ValueError(
            "Data has {} components but {} names provided: {}".format(
                len(data), len(data_names), str(data_names)
            )
        )
    return data_names


Expand All @@ -143,6 +155,46 @@ def check_coordinates(coordinates):
return coordinates


def check_extra_coords_names(coordinates, extra_coords_names):
    """
    Check extra_coords_names against coordinates.

    A single string passed as ``extra_coords_names`` is wrapped in a
    1-element tuple. Every element of ``coordinates`` after the first two is
    counted as an extra coordinate, and there must be exactly one name for
    each of them.

    Parameters
    ----------
    coordinates : sequence
        The coordinates. Elements from index 2 onwards are the extra ones.
    extra_coords_names : str or sequence of str
        Names for the extra coordinates.

    Returns
    -------
    extra_coords_names : tuple or sequence of str
        One name per extra coordinate.

    Raises
    ------
    ValueError
        If ``extra_coords_names`` is None or if the number of names differs
        from the number of extra coordinates.

    Examples
    --------

    >>> import numpy as np
    >>> coordinates = [np.array(10)]*3
    >>> check_extra_coords_names(coordinates, "upward")
    ('upward',)
    >>> check_extra_coords_names(coordinates, ("upward",))
    ('upward',)
    >>> coordinates = [np.array(10)]*4
    >>> check_extra_coords_names(coordinates, ("upward", "time"))
    ('upward', 'time')
    """
    # Reject a missing value up front
    if extra_coords_names is None:
        raise ValueError(
            "Invalid extra_coords_names equal to None. "
            + "When passing one or more extra coordinate, "
            + "extra_coords_names cannot be None."
        )
    # Normalize a lone string into a 1-element tuple
    names = (
        (extra_coords_names,)
        if isinstance(extra_coords_names, str)
        else extra_coords_names
    )
    # There must be one name for each coordinate beyond the first two
    n_extra_coords = len(coordinates[2:])
    if n_extra_coords == len(names):
        return names
    raise ValueError(
        "Invalid extra_coords_names '{}'. ".format(names)
        + "Number of extra coordinates names must match the number of "
        + "additional coordinates ('{}').".format(n_extra_coords)
    )


def check_fit_input(coordinates, data, weights, unpack=True):
"""
Validate the inputs to the fit method of gridders.
Expand Down
19 changes: 10 additions & 9 deletions verde/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from ..base.base_classes import (
BaseGridder,
BaseBlockCrossValidator,
get_data_names,
get_instance_region,
)
from ..coordinates import grid_coordinates, scatter_points
Expand Down Expand Up @@ -46,28 +45,30 @@ def test_get_data_names():
data2 = tuple([np.arange(10)] * 2)
data3 = tuple([np.arange(10)] * 3)
# Test the default names
assert get_data_names(data1, data_names=None) == ("scalars",)
assert get_data_names(data2, data_names=None) == (
gridder = BaseGridder()
assert gridder._get_data_names(data1, data_names=None) == ("scalars",)
assert gridder._get_data_names(data2, data_names=None) == (
"east_component",
"north_component",
)
assert get_data_names(data3, data_names=None) == (
assert gridder._get_data_names(data3, data_names=None) == (
"east_component",
"north_component",
"vertical_component",
)
# Test custom names
assert get_data_names(data1, data_names=("a",)) == ("a",)
assert get_data_names(data2, data_names=("a", "b")) == ("a", "b")
assert get_data_names(data3, data_names=("a", "b", "c")) == ("a", "b", "c")
assert gridder._get_data_names(data1, data_names=("a",)) == ("a",)
assert gridder._get_data_names(data2, data_names=("a", "b")) == ("a", "b")
assert gridder._get_data_names(data3, data_names=("a", "b", "c")) == ("a", "b", "c")


def test_get_data_names_fails():
    "Check if fails for invalid data types"
    gridder = BaseGridder()
    # More components than there are default names available
    with pytest.raises(ValueError):
        gridder._get_data_names(tuple([np.arange(5)] * 4), data_names=None)
    # Number of custom names doesn't match the number of components
    with pytest.raises(ValueError):
        gridder._get_data_names(tuple([np.arange(5)] * 2), data_names=("meh",))


def test_get_instance_region():
Expand Down
13 changes: 8 additions & 5 deletions verde/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_partition_by_sum_fails_no_partitions():
assert "Could not find partition points" in str(error)


def test_build_grid():
def test_make_xarray_grid():
"""
Check if xarray.Dataset is correctly created
"""
Expand All @@ -121,7 +121,7 @@ def test_build_grid():
assert grid.dummy.shape == (5, 6)


def test_build_grid_multiple_data():
def test_make_xarray_grid_multiple_data():
"""
Check if xarray.Dataset with multiple data is correctly created
"""
Expand All @@ -138,7 +138,7 @@ def test_build_grid_multiple_data():
assert dataset["data_{}".format(i)].shape == (5, 6)


def test_build_grid_extra_coords():
def test_make_xarray_grid_extra_coords():
"""
Check if xarray.Dataset with extra coords is correctly created
"""
Expand All @@ -163,7 +163,7 @@ def test_build_grid_extra_coords():
assert dataset.time.shape == (5, 6)


def test_build_grid_invalid_names():
def test_make_xarray_grid_invalid_names():
"""
Check if errors are raised after invalid data names
"""
Expand All @@ -174,13 +174,16 @@ def test_build_grid_invalid_names():
data = np.ones_like(coordinates[0])
with pytest.raises(ValueError):
make_xarray_grid(coordinates, data, data_names=["bla_1", "bla_2"])
# data_names equal to None
with pytest.raises(ValueError):
make_xarray_grid(coordinates, data, data_names=None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch! Even dims cannot be None!

# Multiple data, single data_name
data = tuple(i * np.ones_like(coordinates[0]) for i in (1, 2))
with pytest.raises(ValueError):
make_xarray_grid(coordinates, data, data_names="blabla")


def test_build_grid_invalid_extra_coords():
def test_make_xarray_grid_invalid_extra_coords():
"""
Check if errors are raised after invalid extra coords
"""
Expand Down
Loading