Skip to content

Commit

Permalink
add utility.xr_check_coords (#125)
Browse files Browse the repository at this point in the history
Co-authored-by: Markus Nagel <mark141@mi.fu-berlin.de>
Co-authored-by: Cagtay Fabry <43667554+CagtayFabry@users.noreply.github.com>
Co-authored-by: vhirtham <volker.hirthammer@bam.de>
  • Loading branch information
4 people authored Oct 9, 2020
1 parent a545f07 commit ca9cf39
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 6 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
- add basic schema layout and `GmawProcess` class for arc welding process implementation [#104]
- add example notebook and documentation for arc welding process [#104]
- fix propagating the `name` attribute when reading an ndarray `TimeSeries` object back from ASDF files [#104]
- fix `pint` regression in `TimeSeries` when mixing integer and float values
- fix `pint` regression in `TimeSeries` when mixing integer and float values [#121]
- add `pint` compatibility to some `geometry` classes (**experimental**)
- when passing quantities to constructors (and some functions), values get converted to default unit `mm` and passed on as magnitude
- old behavior is preserved
- add `weldx.utility.xr_check_coords` function to check coordinates of xarray object against dtype and value restrictions [#125]


## 0.2.0 (30.07.2020)
Expand Down
6 changes: 3 additions & 3 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@
"pandas": ("https://pandas.pydata.org/pandas-docs/stable", None),
"xarray": ("http://xarray.pydata.org/en/stable", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
"matplotlib": ("https://matplotlib.org", None),
"dask": ("https://docs.dask.org/en/latest", None),
"numba": ("https://numba.pydata.org/numba-doc/latest", None),
# "matplotlib": ("https://matplotlib.org", None),
# "dask": ("https://docs.dask.org/en/latest", None),
# "numba": ("https://numba.pydata.org/numba-doc/latest", None),
"pint": ("https://pint.readthedocs.io/en/stable", None),
}

Expand Down
68 changes: 68 additions & 0 deletions tests/test_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,71 @@ def test_xf_fill_all():

with pytest.raises(ValueError):
ut.xr_fill_all(da3, order="wrong")


_dax_check = xr.DataArray(
data=np.ones((2, 2, 2, 4, 3)),
dims=["d1", "d2", "d3", "d4", "d5"],
coords={
"d1": np.array([-1, 1], dtype=float),
"d2": np.array([-1, 1], dtype=int),
"d3": pd.DatetimeIndex(["2020-05-01", "2020-05-03"]),
"d4": pd.TimedeltaIndex([0, 1, 2, 3], "s"),
"d5": ["x", "y", "z"],
},
)

_dax_ref = dict(
d1={"values": np.array([-1, 1]), "dtype": "float"},
d2={"values": np.array([-1, 1]), "dtype": int},
d3={
"values": pd.DatetimeIndex(["2020-05-01", "2020-05-03"]),
"dtype": ["datetime64[ns]", "timedelta64[ns]"],
},
d4={
"values": pd.TimedeltaIndex([0, 1, 2, 3], "s"),
"dtype": ["datetime64[ns]", "timedelta64[ns]"],
},
d5={"values": ["x", "y", "z"], "dtype": "<U1"},
)


@pytest.mark.parametrize(
"dax, ref_dict",
[
(_dax_check, _dax_ref),
(_dax_check.coords, _dax_ref),
(_dax_check, {"d1": {"dtype": ["float64", int]}}),
(_dax_check, {"d2": {"dtype": ["float64", int]}}),
(_dax_check, {"no_dim": {"optional": True, "dtype": float}}),
(_dax_check, {"d5": {"dtype": str}}),
(_dax_check, {"d5": {"dtype": [str]}}),
(_dax_check, {"d4": {"dtype": "timedelta64"}}),
(_dax_check, {"d3": {"dtype": ["datetime64", "timedelta64"]}}),
],
)
def test_xr_check_coords(dax, ref_dict):
"""Test weldx.utility.xr_check_coords function."""
assert ut.xr_check_coords(dax, ref_dict)


@pytest.mark.parametrize(
"dax, ref_dict, exception_type",
[
(_dax_check, {"d1": {"dtype": int}}, TypeError),
(_dax_check, {"d1": {"dtype": int, "optional": True}}, TypeError),
(_dax_check, {"no_dim": {"dtype": float}}, KeyError),
(
_dax_check,
{"d5": {"values": ["x", "noty", "z"], "dtype": "str"}},
ValueError,
),
(_dax_check, {"d1": {"dtype": [int, str, bool]}}, TypeError),
(_dax_check, {"d4": {"dtype": "datetime64"}}, TypeError),
({"d4": np.arange(4)}, {"d4": {"dtype": "int"}}, ValueError),
],
)
def test_xr_check_coords_exception(dax, ref_dict, exception_type):
"""Test weldx.utility.xr_check_coords function."""
with pytest.raises(exception_type):
ut.xr_check_coords(dax, ref_dict)
19 changes: 17 additions & 2 deletions weldx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,23 @@ def __init__(
coordinates = self._build_coordinates(coordinates, time)

if construction_checks:
ut.xr_check_coords(
coordinates,
dict(
c={"values": ["x", "y", "z"]},
time={"dtype": "timedelta64", "optional": True},
),
)

ut.xr_check_coords(
orientation,
dict(
c={"values": ["x", "y", "z"]},
v={"values": [0, 1, 2]},
time={"dtype": "timedelta64", "optional": True},
),
)

orientation = xr.apply_ufunc(
normalize,
orientation,
Expand Down Expand Up @@ -636,7 +653,6 @@ def _build_orientation(
"""
if isinstance(orientation, xr.DataArray):
return orientation
# TODO: Test if xarray has correct format

time_orientation = None
if isinstance(orientation, Rot):
Expand Down Expand Up @@ -667,7 +683,6 @@ def _build_coordinates(coordinates, time: pd.DatetimeIndex = None):
"""
if isinstance(coordinates, xr.DataArray):
return coordinates
# TODO: Test if xarray has correct format

time_coordinates = None
if not isinstance(coordinates, (np.ndarray, pint.Quantity)):
Expand Down
131 changes: 131 additions & 0 deletions weldx/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,137 @@ def xr_interp_like(
return result


def _check_dtype(var_dtype, ref_dtype: dict) -> bool:
"""Check if dtype matches a reference dtype (or is subdtype).
Parameters
----------
var_dtype : numpy dtype
A numpy-dtype to test against.
ref_dtype : dict
Python type or string description
Returns
-------
bool
True if dtypes matches.
"""
if var_dtype != np.dtype(ref_dtype):
if isinstance(ref_dtype, str):
if (
"timedelta64" in ref_dtype
or "datetime64" in ref_dtype
and np.issubdtype(var_dtype, np.dtype(ref_dtype))
):
return True

if not (
np.issubdtype(var_dtype, np.dtype(ref_dtype)) and np.dtype(ref_dtype) == str
):
return False

return True


def xr_check_coords(dax: xr.DataArray, ref: dict) -> bool:
"""Validate the coordinates of the DataArray against a reference dictionary.
The reference dictionary should have the dimensions as keys and those contain
dictionaries with the following keywords (all optional):
``values``
Specify exact coordinate values to match.
``dtype`` : str or type
Ensure coordinate dtype matches at least one of the given dtypes.
``optional`` : boolean
default ``False`` - if ``True``, the dimension has to be in the DataArray dax
Parameters
----------
dax : xarray.DataArray
xarray object which should be validated
ref : dict
reference dictionary
Returns
-------
bool
True, if the test was a success, else an exception is raised
Examples
--------
>>> import pandas as pd
>>> import xarray as xr
>>> import weldx as wx
>>> dax = xr.DataArray(
... data=np.ones((3, 2, 3)),
... dims=["d1", "d2", "d3"],
... coords={
... "d1": np.array([-1, 0, 2], dtype=int),
... "d2": pd.DatetimeIndex(["2020-05-01", "2020-05-03"]),
... "d3": ["x", "y", "z"],
... }
... )
>>> ref = dict(
... d1={"optional": True, "values": np.array([-1, 0, 2], dtype=int)},
... d2={
... "values": pd.DatetimeIndex(["2020-05-01", "2020-05-03"]),
... "dtype": ["datetime64[ns]", "timedelta64[ns]"],
... },
... d3={"values": ["x", "y", "z"], "dtype": "<U1"},
... )
>>> wx.utility.xr_check_coords(dax, ref)
True
"""
# only process the coords of the xarray
if isinstance(dax, (xr.DataArray, xr.Dataset)):
coords = dax.coords
elif isinstance(
dax,
(
xr.core.coordinates.DataArrayCoordinates,
xr.core.coordinates.DatasetCoordinates,
),
):
coords = dax
else:
raise ValueError("Input variable is not an xarray object")

for key, check in ref.items():
# check if the optional key is set to true
if "optional" in check:
if check["optional"] and key not in coords:
# skip this key - it is not in dax
continue

if key not in coords:
# Attributes not found in coords
raise KeyError(f"Could not find required coordinate '{key}'.")

# only if the key "values" is given do the validation
if "values" in check:
if not (coords[key].values == check["values"]).all():
raise ValueError(f"Value mismatch in DataArray and ref['{key}']")

# only if the key "dtype" is given do the validation
if "dtype" in check:
dtype_list = check["dtype"]
if not isinstance(dtype_list, list):
dtype_list = [dtype_list]
if not any(
_check_dtype(coords[key].dtype, var_dtype) for var_dtype in dtype_list
):
raise TypeError(
f"Mismatch in the dtype of the DataArray and ref['{key}']"
)

return True


def xr_3d_vector(data, times=None) -> xr.DataArray:
"""Create an xarray 3d vector with correctly named dimensions and coordinates.
Expand Down

0 comments on commit ca9cf39

Please sign in to comment.