Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GeometryFM.contains (optional use of shapely) #700

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions mikeio/dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,7 @@ def __dataset_read_item_time_func(

def extract_track(
self,
track: pd.DataFrame,
track: pd.DataFrame | str,
method: Literal["nearest", "inverse_distance"] = "nearest",
dtype: Any = np.float32,
) -> "Dataset":
Expand All @@ -997,11 +997,10 @@ def extract_track(

Parameters
---------
track: pandas.DataFrame
track: pandas.DataFrame or str
with DatetimeIndex and (x, y) of track points as first two columns
x,y coordinates must be in same coordinate system as dfsu
track: str
filename of csv or dfs0 file containing t,x,y
or filename of csv or dfs0 file containing t,x,y
method: str, optional
Spatial interpolation method ('nearest' or 'inverse_distance')
default='nearest'
Expand Down
42 changes: 25 additions & 17 deletions mikeio/spatial/_FM_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def get_2d_interpolant(
elif n_nearest > 1:
weights = get_idw_interpolant(dists, p=p)
if not extrapolate:
weights[~self.contains(xy), :] = np.nan # type: ignore
weights[~self.contains(xy, strategy="shapely"), :] = np.nan # type: ignore
else:
ValueError("n_nearest must be at least 1")

Expand Down Expand Up @@ -696,7 +696,7 @@ def _find_element_2d(self, coords: np.ndarray) -> Any:
if not element_found and self.n_elements > 1:
many_nearest, _ = self._find_n_nearest_2d_elements(
coords[k, :],
n=min(self.n_elements, 10), # TODO is 10 enough?
n=min(self.n_elements, 100), # TODO is 10 enough?
)
for p in many_nearest[2:]: # we have already tried the two first above
nodes = self.element_table[p]
Expand All @@ -708,6 +708,8 @@ def _find_element_2d(self, coords: np.ndarray) -> Any:
break

if not element_found:
# make an extra check
# if not self.contains(coords[k]):
points_outside.append(k)

if len(points_outside) > 0:
Expand Down Expand Up @@ -843,7 +845,13 @@ def boundary_polylines(self) -> BoundaryPolylines:
"""Lists of closed polylines defining domain outline"""
return self._get_boundary_polylines()

def contains(self, points: np.ndarray) -> np.ndarray:
@cached_property
def _domain(self) -> Any:
return self.to_shapely().buffer(0)

def contains(
self, points: np.ndarray, strategy: Literal["loop", "shapely"] = "loop"
) -> np.ndarray:
"""test if a list of points are contained by mesh

Parameters
Expand All @@ -856,25 +864,25 @@ def contains(self, points: np.ndarray) -> np.ndarray:
bool array
True for points inside, False otherwise
"""
import matplotlib.path as mp # type: ignore

points = np.atleast_2d(points)

exterior = self.boundary_polylines.exteriors[0]
cnts = mp.Path(exterior.xy).contains_points(points)
if strategy == "shapely":
from shapely.geometry import Point # type: ignore

if self.boundary_polylines.n_exteriors > 1:
# in case of several dis-joint outer domains
for exterior in self.boundary_polylines.exteriors[1:]:
in_domain = mp.Path(exterior.xy).contains_points(points)
cnts = np.logical_or(cnts, in_domain)
domain = self._domain

# subtract any holes
for interior in self.boundary_polylines.interiors:
in_hole = mp.Path(interior.xy).contains_points(points)
cnts = np.logical_and(cnts, ~in_hole)
return np.array([domain.contains(Point(p)) for p in points])
else:
result = np.zeros(points.shape[0], dtype=bool)

for i, p in enumerate(points):
for elem in self.element_table:
coords = self.node_coordinates[elem]
if self._point_in_polygon(coords[:, 0], coords[:, 1], p[0], p[1]):
result[i] = True
break

return cnts
return result

def __contains__(self, pt: np.ndarray) -> bool:
return self.contains(pt)[0]
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
"PyYAML",
"tqdm",
"xarray",
"shapely",
]

authors = [
Expand Down Expand Up @@ -49,7 +50,6 @@ dev = ["pytest",
"quartodoc==0.7.6",
"shapely",
"pyproj",
"xarray",
"netcdf4",
"rasterio",
"polars",
Expand All @@ -63,14 +63,14 @@ notebooks= [
"nbformat",
"nbconvert",
"jupyter",
"xarray",
"netcdf4",
"rasterio",
"geopandas",
"scikit-learn",
"matplotlib",
"folium",
"mapclassify",
"shapely",
]

[project.urls]
Expand Down
33 changes: 33 additions & 0 deletions tests/test_geometry_fm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import mikeio
from mikeio.spatial import GeometryFM2D, GeometryFM3D
from mikeio.exceptions import OutsideModelDomainError
from mikeio.spatial import GeometryPoint2D
Expand Down Expand Up @@ -238,6 +239,37 @@ def test_layered(simple_3d_geom: GeometryFM3D):
assert "layers: 2" in repr(g2)


def test_contains_complex_geometry():

msh = mikeio.open("tests/testdata/gulf.mesh")

points = [
[300_000, 3_200_000],
[400_000, 3_000_000],
[800_000, 2_750_000],
[1_200_000, 2_700_000],
]

res = msh.geometry.contains(points)

assert all(res)

res2 = msh.geometry.contains(points, strategy="shapely")
assert all(res2)


def test_find_index_in_highres_quad_area():

dfs = mikeio.open("tests/testdata/coastal_quad.dfsu")

pts = [(439166.047, 6921703.975), (439297.166, 6921728.645)]

idx = dfs.geometry.find_index(coords=pts)

assert len(idx) == 2
for i in idx:
assert i >= 0

def test_equality():
nc = [
(0.0, 0.0, 0.0), # 0
Expand Down Expand Up @@ -277,3 +309,4 @@ def test_equality_shifted_coords():

g2 = GeometryFM2D(node_coordinates=nc2, element_table=el, projection="LONG/LAT")
assert g != g2

Binary file added tests/testdata/coastal_quad.dfsu
Binary file not shown.
Loading
Loading